zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
  10use edit_prediction_context::{
  11    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  12    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  13};
  14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  15use futures::AsyncReadExt as _;
  16use futures::channel::{mpsc, oneshot};
  17use gpui::http_client::{AsyncBody, Method};
  18use gpui::{
  19    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  20    http_client, prelude::*,
  21};
  22use language::BufferSnapshot;
  23use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  24use language_model::{LlmApiToken, RefreshLlmTokenListener};
  25use project::Project;
  26use release_channel::AppVersion;
  27use serde::de::DeserializeOwned;
  28use std::collections::{HashMap, VecDeque, hash_map};
  29use std::path::Path;
  30use std::str::FromStr as _;
  31use std::sync::Arc;
  32use std::time::{Duration, Instant};
  33use thiserror::Error;
  34use util::rel_path::RelPathBuf;
  35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  36
  37mod prediction;
  38mod provider;
  39
  40use crate::prediction::EditPrediction;
  41pub use provider::ZetaEditPredictionProvider;
  42
/// Edits to the same buffer that arrive within this interval of the previously
/// recorded change are merged into that change (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the limit is reached, the oldest half of the events is dropped at once
/// rather than one event at a time (see `Zeta::push_event`).
const MAX_EVENT_COUNT: usize = 16;

/// Context-gathering options used by [`DEFAULT_OPTIONS`].
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    // NOTE(review): presumably disables declaration retrieval by default — confirm.
    max_retrieved_declarations: 0,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};

/// Options a newly constructed [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  68
/// Feature flag for zeta2 edit predictions (flag name `"zeta2"`).
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Returns `false`, so presumably staff are not opted in automatically.
    fn enabled_for_staff() -> bool {
        false
    }
}

/// Wrapper that stores the shared [`Zeta`] entity as a gpui global
/// (read by `Zeta::try_global`, created by `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  83
/// Edit prediction engine: the client and LLM token used to reach the
/// prediction service, per-project tracking state, and the active options.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    // Refreshed whenever the global `RefreshLlmTokenListener` emits (see `Zeta::new`).
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set by `handle_api_response` when the server reports this Zed version is too old.
    update_required: bool,
    // Installed by `debug_info`; when present, each prediction request is reported here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}

/// Tunable parameters for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    pub context: EditPredictionContextOptions,
    // Prompt size budget passed to `make_cloud_request`.
    pub max_prompt_bytes: usize,
    // Byte budget for diagnostics gathered near the cursor.
    pub max_diagnostic_bytes: usize,
    pub prompt_format: predict_edits_v3::PromptFormat,
    // Passed to `SyntaxIndex::new` when a project is first seen.
    pub file_indexing_parallelism: usize,
}

/// Per-request debug information sent on the channel returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    pub request: predict_edits_v3::PredictEditsRequest,
    // Time spent in `EditPredictionContext::gather_context`.
    pub retrieval_time: TimeDelta,
    pub buffer: WeakEntity<Buffer>,
    pub position: language::Anchor,
    // Prompt rendered locally from `request`, or the rendering error message.
    pub local_prompt: Result<String, String>,
    // Receives the server response, or an error string if the request failed or was skipped.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Local alias for the predict_edits_v3 protocol's debug info type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;

/// State [`Zeta`] keeps for each registered project.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    // Recorded buffer changes, oldest first; bounded by MAX_EVENT_COUNT.
    events: VecDeque<Event>,
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    current_prediction: Option<CurrentEditPrediction>,
}

/// The prediction currently offered for a project, plus the buffer it was requested from.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 127
 128impl CurrentEditPrediction {
 129    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 130        let Some(new_edits) = self
 131            .prediction
 132            .interpolate(&self.prediction.buffer.read(cx))
 133        else {
 134            return false;
 135        };
 136
 137        if self.prediction.buffer != old_prediction.prediction.buffer {
 138            return true;
 139        }
 140
 141        let Some(old_edits) = old_prediction
 142            .prediction
 143            .interpolate(&old_prediction.prediction.buffer.read(cx))
 144        else {
 145            return true;
 146        };
 147
 148        // This reduces the occurrence of UI thrash from replacing edits
 149        //
 150        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 151        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 152            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 153            && old_edits.len() == 1
 154            && new_edits.len() == 1
 155        {
 156            let (old_range, old_text) = &old_edits[0];
 157            let (new_range, new_text) = &new_edits[0];
 158            new_range == old_range && new_text.starts_with(old_text)
 159        } else {
 160            true
 161        }
 162    }
 163}
 164
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction targets another buffer but was requested from this one.
    Jump { prediction: &'a EditPrediction },
}

/// A buffer tracked by [`Zeta`]: its last recorded snapshot and the
/// subscriptions created for it in `Zeta::register_buffer_impl`.
struct RegisteredBuffer {
    snapshot: BufferSnapshot,
    // [buffer edit subscription, buffer release observer].
    _subscriptions: [gpui::Subscription; 2],
}

/// A change recorded in a project's event history and sent with prediction requests.
#[derive(Clone)]
pub enum Event {
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 185
 186impl Zeta {
 187    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 188        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 189    }
 190
 191    pub fn global(
 192        client: &Arc<Client>,
 193        user_store: &Entity<UserStore>,
 194        cx: &mut App,
 195    ) -> Entity<Self> {
 196        cx.try_global::<ZetaGlobal>()
 197            .map(|global| global.0.clone())
 198            .unwrap_or_else(|| {
 199                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 200                cx.set_global(ZetaGlobal(zeta.clone()));
 201                zeta
 202            })
 203    }
 204
 205    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 206        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 207
 208        Self {
 209            projects: HashMap::new(),
 210            client,
 211            user_store,
 212            options: DEFAULT_OPTIONS,
 213            llm_token: LlmApiToken::default(),
 214            _llm_token_subscription: cx.subscribe(
 215                &refresh_llm_token_listener,
 216                |this, _listener, _event, cx| {
 217                    let client = this.client.clone();
 218                    let llm_token = this.llm_token.clone();
 219                    cx.spawn(async move |_this, _cx| {
 220                        llm_token.refresh(&client).await?;
 221                        anyhow::Ok(())
 222                    })
 223                    .detach_and_log_err(cx);
 224                },
 225            ),
 226            update_required: false,
 227            debug_tx: None,
 228        }
 229    }
 230
 231    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 232        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 233        self.debug_tx = Some(debug_watch_tx);
 234        debug_watch_rx
 235    }
 236
 237    pub fn options(&self) -> &ZetaOptions {
 238        &self.options
 239    }
 240
 241    pub fn set_options(&mut self, options: ZetaOptions) {
 242        self.options = options;
 243    }
 244
 245    pub fn clear_history(&mut self) {
 246        for zeta_project in self.projects.values_mut() {
 247            zeta_project.events.clear();
 248        }
 249    }
 250
 251    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 252        self.user_store.read(cx).edit_prediction_usage()
 253    }
 254
 255    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 256        self.get_or_init_zeta_project(project, cx);
 257    }
 258
 259    pub fn register_buffer(
 260        &mut self,
 261        buffer: &Entity<Buffer>,
 262        project: &Entity<Project>,
 263        cx: &mut Context<Self>,
 264    ) {
 265        let zeta_project = self.get_or_init_zeta_project(project, cx);
 266        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 267    }
 268
 269    fn get_or_init_zeta_project(
 270        &mut self,
 271        project: &Entity<Project>,
 272        cx: &mut App,
 273    ) -> &mut ZetaProject {
 274        self.projects
 275            .entry(project.entity_id())
 276            .or_insert_with(|| ZetaProject {
 277                syntax_index: cx.new(|cx| {
 278                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 279                }),
 280                events: VecDeque::new(),
 281                registered_buffers: HashMap::new(),
 282                current_prediction: None,
 283            })
 284    }
 285
 286    fn register_buffer_impl<'a>(
 287        zeta_project: &'a mut ZetaProject,
 288        buffer: &Entity<Buffer>,
 289        project: &Entity<Project>,
 290        cx: &mut Context<Self>,
 291    ) -> &'a mut RegisteredBuffer {
 292        let buffer_id = buffer.entity_id();
 293        match zeta_project.registered_buffers.entry(buffer_id) {
 294            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 295            hash_map::Entry::Vacant(entry) => {
 296                let snapshot = buffer.read(cx).snapshot();
 297                let project_entity_id = project.entity_id();
 298                entry.insert(RegisteredBuffer {
 299                    snapshot,
 300                    _subscriptions: [
 301                        cx.subscribe(buffer, {
 302                            let project = project.downgrade();
 303                            move |this, buffer, event, cx| {
 304                                if let language::BufferEvent::Edited = event
 305                                    && let Some(project) = project.upgrade()
 306                                {
 307                                    this.report_changes_for_buffer(&buffer, &project, cx);
 308                                }
 309                            }
 310                        }),
 311                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 312                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 313                            else {
 314                                return;
 315                            };
 316                            zeta_project.registered_buffers.remove(&buffer_id);
 317                        }),
 318                    ],
 319                })
 320            }
 321        }
 322    }
 323
 324    fn report_changes_for_buffer(
 325        &mut self,
 326        buffer: &Entity<Buffer>,
 327        project: &Entity<Project>,
 328        cx: &mut Context<Self>,
 329    ) -> BufferSnapshot {
 330        let zeta_project = self.get_or_init_zeta_project(project, cx);
 331        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 332
 333        let new_snapshot = buffer.read(cx).snapshot();
 334        if new_snapshot.version != registered_buffer.snapshot.version {
 335            let old_snapshot =
 336                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 337            Self::push_event(
 338                zeta_project,
 339                Event::BufferChange {
 340                    old_snapshot,
 341                    new_snapshot: new_snapshot.clone(),
 342                    timestamp: Instant::now(),
 343                },
 344            );
 345        }
 346
 347        new_snapshot
 348    }
 349
 350    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 351        let events = &mut zeta_project.events;
 352
 353        if let Some(Event::BufferChange {
 354            new_snapshot: last_new_snapshot,
 355            timestamp: last_timestamp,
 356            ..
 357        }) = events.back_mut()
 358        {
 359            // Coalesce edits for the same buffer when they happen one after the other.
 360            let Event::BufferChange {
 361                old_snapshot,
 362                new_snapshot,
 363                timestamp,
 364            } = &event;
 365
 366            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 367                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 368                && old_snapshot.version == last_new_snapshot.version
 369            {
 370                *last_new_snapshot = new_snapshot.clone();
 371                *last_timestamp = *timestamp;
 372                return;
 373            }
 374        }
 375
 376        if events.len() >= MAX_EVENT_COUNT {
 377            // These are halved instead of popping to improve prompt caching.
 378            events.drain(..MAX_EVENT_COUNT / 2);
 379        }
 380
 381        events.push_back(event);
 382    }
 383
 384    fn current_prediction_for_buffer(
 385        &self,
 386        buffer: &Entity<Buffer>,
 387        project: &Entity<Project>,
 388        cx: &App,
 389    ) -> Option<BufferEditPrediction<'_>> {
 390        let project_state = self.projects.get(&project.entity_id())?;
 391
 392        let CurrentEditPrediction {
 393            requested_by_buffer_id,
 394            prediction,
 395        } = project_state.current_prediction.as_ref()?;
 396
 397        if prediction.targets_buffer(buffer.read(cx), cx) {
 398            Some(BufferEditPrediction::Local { prediction })
 399        } else if *requested_by_buffer_id == buffer.entity_id() {
 400            Some(BufferEditPrediction::Jump { prediction })
 401        } else {
 402            None
 403        }
 404    }
 405
 406    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 407        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 408            return;
 409        };
 410
 411        let Some(prediction) = project_state.current_prediction.take() else {
 412            return;
 413        };
 414        let request_id = prediction.prediction.id.into();
 415
 416        let client = self.client.clone();
 417        let llm_token = self.llm_token.clone();
 418        let app_version = AppVersion::global(cx);
 419        cx.spawn(async move |this, cx| {
 420            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 421                http_client::Url::parse(&predict_edits_url)?
 422            } else {
 423                client
 424                    .http_client()
 425                    .build_zed_llm_url("/predict_edits/accept", &[])?
 426            };
 427
 428            let response = cx
 429                .background_spawn(Self::send_api_request::<()>(
 430                    move |builder| {
 431                        let req = builder.uri(url.as_ref()).body(
 432                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 433                        );
 434                        Ok(req?)
 435                    },
 436                    client,
 437                    llm_token,
 438                    app_version,
 439                ))
 440                .await;
 441
 442            Self::handle_api_response(&this, response, cx)?;
 443            anyhow::Ok(())
 444        })
 445        .detach_and_log_err(cx);
 446    }
 447
 448    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 449        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 450            project_state.current_prediction.take();
 451        };
 452    }
 453
 454    pub fn refresh_prediction(
 455        &mut self,
 456        project: &Entity<Project>,
 457        buffer: &Entity<Buffer>,
 458        position: language::Anchor,
 459        cx: &mut Context<Self>,
 460    ) -> Task<Result<()>> {
 461        let request_task = self.request_prediction(project, buffer, position, cx);
 462        let buffer = buffer.clone();
 463        let project = project.clone();
 464
 465        cx.spawn(async move |this, cx| {
 466            if let Some(prediction) = request_task.await? {
 467                this.update(cx, |this, cx| {
 468                    let project_state = this
 469                        .projects
 470                        .get_mut(&project.entity_id())
 471                        .context("Project not found")?;
 472
 473                    let new_prediction = CurrentEditPrediction {
 474                        requested_by_buffer_id: buffer.entity_id(),
 475                        prediction: prediction,
 476                    };
 477
 478                    if project_state
 479                        .current_prediction
 480                        .as_ref()
 481                        .is_none_or(|old_prediction| {
 482                            new_prediction.should_replace_prediction(&old_prediction, cx)
 483                        })
 484                    {
 485                        project_state.current_prediction = Some(new_prediction);
 486                    }
 487                    anyhow::Ok(())
 488                })??;
 489            }
 490            Ok(())
 491        })
 492    }
 493
 494    fn request_prediction(
 495        &mut self,
 496        project: &Entity<Project>,
 497        buffer: &Entity<Buffer>,
 498        position: language::Anchor,
 499        cx: &mut Context<Self>,
 500    ) -> Task<Result<Option<EditPrediction>>> {
 501        let project_state = self.projects.get(&project.entity_id());
 502
 503        let index_state = project_state.map(|state| {
 504            state
 505                .syntax_index
 506                .read_with(cx, |index, _cx| index.state().clone())
 507        });
 508        let options = self.options.clone();
 509        let snapshot = buffer.read(cx).snapshot();
 510        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 511            return Task::ready(Err(anyhow!("No file path for excerpt")));
 512        };
 513        let client = self.client.clone();
 514        let llm_token = self.llm_token.clone();
 515        let app_version = AppVersion::global(cx);
 516        let worktree_snapshots = project
 517            .read(cx)
 518            .worktrees(cx)
 519            .map(|worktree| worktree.read(cx).snapshot())
 520            .collect::<Vec<_>>();
 521        let debug_tx = self.debug_tx.clone();
 522
 523        let events = project_state
 524            .map(|state| {
 525                state
 526                    .events
 527                    .iter()
 528                    .filter_map(|event| match event {
 529                        Event::BufferChange {
 530                            old_snapshot,
 531                            new_snapshot,
 532                            ..
 533                        } => {
 534                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 535
 536                            let old_path = old_snapshot.file().and_then(|f| {
 537                                let old_path = f.full_path(cx);
 538                                if Some(&old_path) != path.as_ref() {
 539                                    Some(old_path)
 540                                } else {
 541                                    None
 542                                }
 543                            });
 544
 545                            // TODO [zeta2] move to bg?
 546                            let diff =
 547                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 548
 549                            if path == old_path && diff.is_empty() {
 550                                None
 551                            } else {
 552                                Some(predict_edits_v3::Event::BufferChange {
 553                                    old_path,
 554                                    path,
 555                                    diff,
 556                                    //todo: Actually detect if this edit was predicted or not
 557                                    predicted: false,
 558                                })
 559                            }
 560                        }
 561                    })
 562                    .collect::<Vec<_>>()
 563            })
 564            .unwrap_or_default();
 565
 566        let diagnostics = snapshot.diagnostic_sets().clone();
 567
 568        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 569            let mut path = f.worktree.read(cx).absolutize(&f.path);
 570            if path.pop() { Some(path) } else { None }
 571        });
 572
 573        // TODO data collection
 574        let can_collect_data = cx.is_staff();
 575
 576        let request_task = cx.background_spawn({
 577            let snapshot = snapshot.clone();
 578            let buffer = buffer.clone();
 579            async move {
 580                let index_state = if let Some(index_state) = index_state {
 581                    Some(index_state.lock_owned().await)
 582                } else {
 583                    None
 584                };
 585
 586                let cursor_offset = position.to_offset(&snapshot);
 587                let cursor_point = cursor_offset.to_point(&snapshot);
 588
 589                let before_retrieval = chrono::Utc::now();
 590
 591                let Some(context) = EditPredictionContext::gather_context(
 592                    cursor_point,
 593                    &snapshot,
 594                    parent_abs_path.as_deref(),
 595                    &options.context,
 596                    index_state.as_deref(),
 597                ) else {
 598                    return Ok((None, None));
 599                };
 600
 601                let retrieval_time = chrono::Utc::now() - before_retrieval;
 602
 603                let (diagnostic_groups, diagnostic_groups_truncated) =
 604                    Self::gather_nearby_diagnostics(
 605                        cursor_offset,
 606                        &diagnostics,
 607                        &snapshot,
 608                        options.max_diagnostic_bytes,
 609                    );
 610
 611                let request = make_cloud_request(
 612                    excerpt_path,
 613                    context,
 614                    events,
 615                    can_collect_data,
 616                    diagnostic_groups,
 617                    diagnostic_groups_truncated,
 618                    None,
 619                    debug_tx.is_some(),
 620                    &worktree_snapshots,
 621                    index_state.as_deref(),
 622                    Some(options.max_prompt_bytes),
 623                    options.prompt_format,
 624                );
 625
 626                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 627                    let (response_tx, response_rx) = oneshot::channel();
 628
 629                    let local_prompt = PlannedPrompt::populate(&request)
 630                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 631                        .map_err(|err| err.to_string());
 632
 633                    debug_tx
 634                        .unbounded_send(PredictionDebugInfo {
 635                            request: request.clone(),
 636                            retrieval_time,
 637                            buffer: buffer.downgrade(),
 638                            local_prompt,
 639                            position,
 640                            response_rx,
 641                        })
 642                        .ok();
 643                    Some(response_tx)
 644                } else {
 645                    None
 646                };
 647
 648                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 649                    if let Some(debug_response_tx) = debug_response_tx {
 650                        debug_response_tx
 651                            .send(Err("Request skipped".to_string()))
 652                            .ok();
 653                    }
 654                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 655                }
 656
 657                let response =
 658                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 659
 660                if let Some(debug_response_tx) = debug_response_tx {
 661                    debug_response_tx
 662                        .send(
 663                            response
 664                                .as_ref()
 665                                .map_err(|err| err.to_string())
 666                                .map(|response| response.0.clone()),
 667                        )
 668                        .ok();
 669                }
 670
 671                response.map(|(res, usage)| (Some(res), usage))
 672            }
 673        });
 674
 675        let buffer = buffer.clone();
 676
 677        cx.spawn({
 678            let project = project.clone();
 679            async move |this, cx| {
 680                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 681                else {
 682                    return Ok(None);
 683                };
 684
 685                // TODO telemetry: duration, etc
 686                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 687            }
 688        })
 689    }
 690
 691    async fn send_prediction_request(
 692        client: Arc<Client>,
 693        llm_token: LlmApiToken,
 694        app_version: SemanticVersion,
 695        request: predict_edits_v3::PredictEditsRequest,
 696    ) -> Result<(
 697        predict_edits_v3::PredictEditsResponse,
 698        Option<EditPredictionUsage>,
 699    )> {
 700        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 701            http_client::Url::parse(&predict_edits_url)?
 702        } else {
 703            client
 704                .http_client()
 705                .build_zed_llm_url("/predict_edits/v3", &[])?
 706        };
 707
 708        Self::send_api_request(
 709            |builder| {
 710                let req = builder
 711                    .uri(url.as_ref())
 712                    .body(serde_json::to_string(&request)?.into());
 713                Ok(req?)
 714            },
 715            client,
 716            llm_token,
 717            app_version,
 718        )
 719        .await
 720    }
 721
 722    fn handle_api_response<T>(
 723        this: &WeakEntity<Self>,
 724        response: Result<(T, Option<EditPredictionUsage>)>,
 725        cx: &mut gpui::AsyncApp,
 726    ) -> Result<T> {
 727        match response {
 728            Ok((data, usage)) => {
 729                if let Some(usage) = usage {
 730                    this.update(cx, |this, cx| {
 731                        this.user_store.update(cx, |user_store, cx| {
 732                            user_store.update_edit_prediction_usage(usage, cx);
 733                        });
 734                    })
 735                    .ok();
 736                }
 737                Ok(data)
 738            }
 739            Err(err) => {
 740                if err.is::<ZedUpdateRequiredError>() {
 741                    cx.update(|cx| {
 742                        this.update(cx, |this, _cx| {
 743                            this.update_required = true;
 744                        })
 745                        .ok();
 746
 747                        let error_message: SharedString = err.to_string().into();
 748                        show_app_notification(
 749                            NotificationId::unique::<ZedUpdateRequiredError>(),
 750                            cx,
 751                            move |cx| {
 752                                cx.new(|cx| {
 753                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 754                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 755                                })
 756                            },
 757                        );
 758                    })
 759                    .ok();
 760                }
 761                Err(err)
 762            }
 763        }
 764    }
 765
 766    async fn send_api_request<Res>(
 767        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 768        client: Arc<Client>,
 769        llm_token: LlmApiToken,
 770        app_version: SemanticVersion,
 771    ) -> Result<(Res, Option<EditPredictionUsage>)>
 772    where
 773        Res: DeserializeOwned,
 774    {
 775        let http_client = client.http_client();
 776        let mut token = llm_token.acquire(&client).await?;
 777        let mut did_retry = false;
 778
 779        loop {
 780            let request_builder = http_client::Request::builder().method(Method::POST);
 781
 782            let request = build(
 783                request_builder
 784                    .header("Content-Type", "application/json")
 785                    .header("Authorization", format!("Bearer {}", token))
 786                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 787            )?;
 788
 789            let mut response = http_client.send(request).await?;
 790
 791            if let Some(minimum_required_version) = response
 792                .headers()
 793                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 794                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 795            {
 796                anyhow::ensure!(
 797                    app_version >= minimum_required_version,
 798                    ZedUpdateRequiredError {
 799                        minimum_version: minimum_required_version
 800                    }
 801                );
 802            }
 803
 804            if response.status().is_success() {
 805                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 806
 807                let mut body = Vec::new();
 808                response.body_mut().read_to_end(&mut body).await?;
 809                return Ok((serde_json::from_slice(&body)?, usage));
 810            } else if !did_retry
 811                && response
 812                    .headers()
 813                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 814                    .is_some()
 815            {
 816                did_retry = true;
 817                token = llm_token.refresh(&client).await?;
 818            } else {
 819                let mut body = String::new();
 820                response.body_mut().read_to_string(&mut body).await?;
 821                anyhow::bail!(
 822                    "Request failed with status: {:?}\nBody: {}",
 823                    response.status(),
 824                    body
 825                );
 826            }
 827        }
 828    }
 829
 830    fn gather_nearby_diagnostics(
 831        cursor_offset: usize,
 832        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 833        snapshot: &BufferSnapshot,
 834        max_diagnostics_bytes: usize,
 835    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 836        // TODO: Could make this more efficient
 837        let mut diagnostic_groups = Vec::new();
 838        for (language_server_id, diagnostics) in diagnostic_sets {
 839            let mut groups = Vec::new();
 840            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 841            diagnostic_groups.extend(
 842                groups
 843                    .into_iter()
 844                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 845            );
 846        }
 847
 848        // sort by proximity to cursor
 849        diagnostic_groups.sort_by_key(|group| {
 850            let range = &group.entries[group.primary_ix].range;
 851            if range.start >= cursor_offset {
 852                range.start - cursor_offset
 853            } else if cursor_offset >= range.end {
 854                cursor_offset - range.end
 855            } else {
 856                (cursor_offset - range.start).min(range.end - cursor_offset)
 857            }
 858        });
 859
 860        let mut results = Vec::new();
 861        let mut diagnostic_groups_truncated = false;
 862        let mut diagnostics_byte_count = 0;
 863        for group in diagnostic_groups {
 864            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 865            diagnostics_byte_count += raw_value.get().len();
 866            if diagnostics_byte_count > max_diagnostics_bytes {
 867                diagnostic_groups_truncated = true;
 868                break;
 869            }
 870            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 871        }
 872
 873        (results, diagnostic_groups_truncated)
 874    }
 875
 876    // TODO: Dedupe with similar code in request_prediction?
 877    pub fn cloud_request_for_zeta_cli(
 878        &mut self,
 879        project: &Entity<Project>,
 880        buffer: &Entity<Buffer>,
 881        position: language::Anchor,
 882        cx: &mut Context<Self>,
 883    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 884        let project_state = self.projects.get(&project.entity_id());
 885
 886        let index_state = project_state.map(|state| {
 887            state
 888                .syntax_index
 889                .read_with(cx, |index, _cx| index.state().clone())
 890        });
 891        let options = self.options.clone();
 892        let snapshot = buffer.read(cx).snapshot();
 893        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 894            return Task::ready(Err(anyhow!("No file path for excerpt")));
 895        };
 896        let worktree_snapshots = project
 897            .read(cx)
 898            .worktrees(cx)
 899            .map(|worktree| worktree.read(cx).snapshot())
 900            .collect::<Vec<_>>();
 901
 902        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 903            let mut path = f.worktree.read(cx).absolutize(&f.path);
 904            if path.pop() { Some(path) } else { None }
 905        });
 906
 907        cx.background_spawn(async move {
 908            let index_state = if let Some(index_state) = index_state {
 909                Some(index_state.lock_owned().await)
 910            } else {
 911                None
 912            };
 913
 914            let cursor_point = position.to_point(&snapshot);
 915
 916            let debug_info = true;
 917            EditPredictionContext::gather_context(
 918                cursor_point,
 919                &snapshot,
 920                parent_abs_path.as_deref(),
 921                &options.context,
 922                index_state.as_deref(),
 923            )
 924            .context("Failed to select excerpt")
 925            .map(|context| {
 926                make_cloud_request(
 927                    excerpt_path.into(),
 928                    context,
 929                    // TODO pass everything
 930                    Vec::new(),
 931                    false,
 932                    Vec::new(),
 933                    false,
 934                    None,
 935                    debug_info,
 936                    &worktree_snapshots,
 937                    index_state.as_deref(),
 938                    Some(options.max_prompt_bytes),
 939                    options.prompt_format,
 940                )
 941            })
 942        })
 943    }
 944
 945    pub fn wait_for_initial_indexing(
 946        &mut self,
 947        project: &Entity<Project>,
 948        cx: &mut App,
 949    ) -> Task<Result<()>> {
 950        let zeta_project = self.get_or_init_zeta_project(project, cx);
 951        zeta_project
 952            .syntax_index
 953            .read(cx)
 954            .wait_for_initial_file_indexing(cx)
 955    }
 956}
 957
 958#[derive(Error, Debug)]
 959#[error(
 960    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 961)]
 962pub struct ZedUpdateRequiredError {
 963    minimum_version: SemanticVersion,
 964}
 965
 966fn make_cloud_request(
 967    excerpt_path: Arc<Path>,
 968    context: EditPredictionContext,
 969    events: Vec<predict_edits_v3::Event>,
 970    can_collect_data: bool,
 971    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 972    diagnostic_groups_truncated: bool,
 973    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 974    debug_info: bool,
 975    worktrees: &Vec<worktree::Snapshot>,
 976    index_state: Option<&SyntaxIndexState>,
 977    prompt_max_bytes: Option<usize>,
 978    prompt_format: PromptFormat,
 979) -> predict_edits_v3::PredictEditsRequest {
 980    let mut signatures = Vec::new();
 981    let mut declaration_to_signature_index = HashMap::default();
 982    let mut referenced_declarations = Vec::new();
 983
 984    for snippet in context.declarations {
 985        let project_entry_id = snippet.declaration.project_entry_id();
 986        let Some(path) = worktrees.iter().find_map(|worktree| {
 987            worktree.entry_for_id(project_entry_id).map(|entry| {
 988                let mut full_path = RelPathBuf::new();
 989                full_path.push(worktree.root_name());
 990                full_path.push(&entry.path);
 991                full_path
 992            })
 993        }) else {
 994            continue;
 995        };
 996
 997        let parent_index = index_state.and_then(|index_state| {
 998            snippet.declaration.parent().and_then(|parent| {
 999                add_signature(
1000                    parent,
1001                    &mut declaration_to_signature_index,
1002                    &mut signatures,
1003                    index_state,
1004                )
1005            })
1006        });
1007
1008        let (text, text_is_truncated) = snippet.declaration.item_text();
1009        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1010            path: path.as_std_path().into(),
1011            text: text.into(),
1012            range: snippet.declaration.item_line_range(),
1013            text_is_truncated,
1014            signature_range: snippet.declaration.signature_range_in_item_text(),
1015            parent_index,
1016            signature_score: snippet.score(DeclarationStyle::Signature),
1017            declaration_score: snippet.score(DeclarationStyle::Declaration),
1018            score_components: snippet.components,
1019        });
1020    }
1021
1022    let excerpt_parent = index_state.and_then(|index_state| {
1023        context
1024            .excerpt
1025            .parent_declarations
1026            .last()
1027            .and_then(|(parent, _)| {
1028                add_signature(
1029                    *parent,
1030                    &mut declaration_to_signature_index,
1031                    &mut signatures,
1032                    index_state,
1033                )
1034            })
1035    });
1036
1037    predict_edits_v3::PredictEditsRequest {
1038        excerpt_path,
1039        excerpt: context.excerpt_text.body,
1040        excerpt_line_range: context.excerpt.line_range,
1041        excerpt_range: context.excerpt.range,
1042        cursor_point: predict_edits_v3::Point {
1043            line: predict_edits_v3::Line(context.cursor_point.row),
1044            column: context.cursor_point.column,
1045        },
1046        referenced_declarations,
1047        signatures,
1048        excerpt_parent,
1049        events,
1050        can_collect_data,
1051        diagnostic_groups,
1052        diagnostic_groups_truncated,
1053        git_info,
1054        debug_info,
1055        prompt_max_bytes,
1056        prompt_format,
1057    }
1058}
1059
1060fn add_signature(
1061    declaration_id: DeclarationId,
1062    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1063    signatures: &mut Vec<Signature>,
1064    index: &SyntaxIndexState,
1065) -> Option<usize> {
1066    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1067        return Some(*signature_index);
1068    }
1069    let Some(parent_declaration) = index.declaration(declaration_id) else {
1070        log::error!("bug: missing parent declaration");
1071        return None;
1072    };
1073    let parent_index = parent_declaration.parent().and_then(|parent| {
1074        add_signature(parent, declaration_to_signature_index, signatures, index)
1075    });
1076    let (text, text_is_truncated) = parent_declaration.signature_text();
1077    let signature_index = signatures.len();
1078    signatures.push(Signature {
1079        text: text.into(),
1080        text_is_truncated,
1081        parent_index,
1082        range: parent_declaration.signature_line_range(),
1083    });
1084    declaration_to_signature_index.insert(declaration_id, signature_index);
1085    Some(signature_index)
1086}
1087
// Tests for `Zeta` prediction requests. `init_test` installs a fake HTTP client.
// That client forwards every `/predict_edits/v3` request to the test over a
// channel, together with a oneshot sender the test uses to supply the response.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Checks which kind of prediction `current_prediction_for_buffer` reports:
    // - `Local` when the predicted edit targets the queried buffer's own file;
    // - `Jump` (pointing at `2.txt`) when queried from `buffer1` but the edit
    //   targets another file;
    // - `Local` again once `2.txt` is opened and queried as its own buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        // (2.txt has the same number of lines as 1.txt, so `snapshot1`'s max row
        // also spans the whole of 2.txt.)
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Once the target file is open, the same prediction is local to it.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point. The response replaces
    // every line of the file; the resulting prediction is expected to reduce to
    // a single insertion of " are you?" at (1, 3).
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Edits a buffer after `register_buffer`. The next request should carry
    // exactly one `BufferChange` event whose unified diff shows the inserted
    // "How". The response is then handled as in `test_simple_request`.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the (empty) second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Publishes one LSP diagnostic for foo.md before the buffer is opened. The
    // request should contain it as a single serialized diagnostic group. Only
    // the request is inspected; the response sender is held but never used.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // The expected range is 8..10: line 1 ("Bye") starts at offset 7, so
        // column 1 is offset 8. The LSP end column 5 lies past the end of that
        // 3-character line, and the expected end is the buffer end (offset 10),
        // presumably because positions are clipped to the line.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up the settings, language, project and language-model globals, a
    // fake HTTP client, and a `Zeta` instance.
    //
    // Returns the `Zeta` entity and a receiver. The receiver yields each
    // deserialized `/predict_edits/v3` request paired with a oneshot sender; the
    // test must use that sender to supply the HTTP response.
    // `/client/llm_tokens` always returns the token "test". Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and wait for its response.
                                // If the test drops the sender, `?` propagates the
                                // oneshot cancellation as this HTTP call's error.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}