zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
  10use edit_prediction_context::{
  11    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  12    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  13};
  14use feature_flags::FeatureFlag;
  15use futures::AsyncReadExt as _;
  16use futures::channel::{mpsc, oneshot};
  17use gpui::http_client::{AsyncBody, Method};
  18use gpui::{
  19    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  20    http_client, prelude::*,
  21};
  22use language::BufferSnapshot;
  23use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  24use language_model::{LlmApiToken, RefreshLlmTokenListener};
  25use project::Project;
  26use release_channel::AppVersion;
  27use serde::de::DeserializeOwned;
  28use std::collections::{HashMap, VecDeque, hash_map};
  29use std::path::Path;
  30use std::str::FromStr as _;
  31use std::sync::Arc;
  32use std::time::{Duration, Instant};
  33use thiserror::Error;
  34use util::rel_path::RelPathBuf;
  35use util::some_or_debug_panic;
  36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  37
  38mod prediction;
  39mod provider;
  40
  41use crate::prediction::EditPrediction;
  42pub use provider::ZetaEditPredictionProvider;
  43
  44const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  45
  46/// Maximum number of events to track.
  47const MAX_EVENT_COUNT: usize = 16;
  48
  49pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
  50    use_imports: true,
  51    max_retrieved_declarations: 0,
  52    excerpt: EditPredictionExcerptOptions {
  53        max_bytes: 512,
  54        min_bytes: 128,
  55        target_before_cursor_over_total_bytes: 0.5,
  56    },
  57    score: EditPredictionScoreOptions {
  58        omit_excerpt_overlaps: true,
  59    },
  60};
  61
  62pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  63    context: DEFAULT_CONTEXT_OPTIONS,
  64    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  65    max_diagnostic_bytes: 2048,
  66    prompt_format: PromptFormat::DEFAULT,
  67    file_indexing_parallelism: 1,
  68};
  69
  70pub struct Zeta2FeatureFlag;
  71
  72impl FeatureFlag for Zeta2FeatureFlag {
  73    const NAME: &'static str = "zeta2";
  74
  75    fn enabled_for_staff() -> bool {
  76        false
  77    }
  78}
  79
  80#[derive(Clone)]
  81struct ZetaGlobal(Entity<Zeta>);
  82
  83impl Global for ZetaGlobal {}
  84
  85pub struct Zeta {
  86    client: Arc<Client>,
  87    user_store: Entity<UserStore>,
  88    llm_token: LlmApiToken,
  89    _llm_token_subscription: Subscription,
  90    projects: HashMap<EntityId, ZetaProject>,
  91    options: ZetaOptions,
  92    update_required: bool,
  93    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
  94}
  95
  96#[derive(Debug, Clone, PartialEq)]
  97pub struct ZetaOptions {
  98    pub context: EditPredictionContextOptions,
  99    pub max_prompt_bytes: usize,
 100    pub max_diagnostic_bytes: usize,
 101    pub prompt_format: predict_edits_v3::PromptFormat,
 102    pub file_indexing_parallelism: usize,
 103}
 104
 105pub struct PredictionDebugInfo {
 106    pub context: EditPredictionContext,
 107    pub retrieval_time: TimeDelta,
 108    pub buffer: WeakEntity<Buffer>,
 109    pub position: language::Anchor,
 110    pub local_prompt: Result<String, String>,
 111    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
 112}
 113
 114pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 115
 116struct ZetaProject {
 117    syntax_index: Entity<SyntaxIndex>,
 118    events: VecDeque<Event>,
 119    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
 120    current_prediction: Option<CurrentEditPrediction>,
 121}
 122
 123#[derive(Debug, Clone)]
 124struct CurrentEditPrediction {
 125    pub requested_by_buffer_id: EntityId,
 126    pub prediction: EditPrediction,
 127}
 128
 129impl CurrentEditPrediction {
 130    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 131        let Some(new_edits) = self
 132            .prediction
 133            .interpolate(&self.prediction.buffer.read(cx))
 134        else {
 135            return false;
 136        };
 137
 138        if self.prediction.buffer != old_prediction.prediction.buffer {
 139            return true;
 140        }
 141
 142        let Some(old_edits) = old_prediction
 143            .prediction
 144            .interpolate(&old_prediction.prediction.buffer.read(cx))
 145        else {
 146            return true;
 147        };
 148
 149        // This reduces the occurrence of UI thrash from replacing edits
 150        //
 151        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 152        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 153            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 154            && old_edits.len() == 1
 155            && new_edits.len() == 1
 156        {
 157            let (old_range, old_text) = &old_edits[0];
 158            let (new_range, new_text) = &new_edits[0];
 159            new_range == old_range && new_text.starts_with(old_text)
 160        } else {
 161            true
 162        }
 163    }
 164}
 165
 166/// A prediction from the perspective of a buffer.
 167#[derive(Debug)]
 168enum BufferEditPrediction<'a> {
 169    Local { prediction: &'a EditPrediction },
 170    Jump { prediction: &'a EditPrediction },
 171}
 172
 173struct RegisteredBuffer {
 174    snapshot: BufferSnapshot,
 175    _subscriptions: [gpui::Subscription; 2],
 176}
 177
 178#[derive(Clone)]
 179pub enum Event {
 180    BufferChange {
 181        old_snapshot: BufferSnapshot,
 182        new_snapshot: BufferSnapshot,
 183        timestamp: Instant,
 184    },
 185}
 186
 187impl Zeta {
 188    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 189        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 190    }
 191
 192    pub fn global(
 193        client: &Arc<Client>,
 194        user_store: &Entity<UserStore>,
 195        cx: &mut App,
 196    ) -> Entity<Self> {
 197        cx.try_global::<ZetaGlobal>()
 198            .map(|global| global.0.clone())
 199            .unwrap_or_else(|| {
 200                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 201                cx.set_global(ZetaGlobal(zeta.clone()));
 202                zeta
 203            })
 204    }
 205
 206    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 207        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 208
 209        Self {
 210            projects: HashMap::new(),
 211            client,
 212            user_store,
 213            options: DEFAULT_OPTIONS,
 214            llm_token: LlmApiToken::default(),
 215            _llm_token_subscription: cx.subscribe(
 216                &refresh_llm_token_listener,
 217                |this, _listener, _event, cx| {
 218                    let client = this.client.clone();
 219                    let llm_token = this.llm_token.clone();
 220                    cx.spawn(async move |_this, _cx| {
 221                        llm_token.refresh(&client).await?;
 222                        anyhow::Ok(())
 223                    })
 224                    .detach_and_log_err(cx);
 225                },
 226            ),
 227            update_required: false,
 228            debug_tx: None,
 229        }
 230    }
 231
 232    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 233        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 234        self.debug_tx = Some(debug_watch_tx);
 235        debug_watch_rx
 236    }
 237
 238    pub fn options(&self) -> &ZetaOptions {
 239        &self.options
 240    }
 241
 242    pub fn set_options(&mut self, options: ZetaOptions) {
 243        self.options = options;
 244    }
 245
 246    pub fn clear_history(&mut self) {
 247        for zeta_project in self.projects.values_mut() {
 248            zeta_project.events.clear();
 249        }
 250    }
 251
 252    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 253        self.user_store.read(cx).edit_prediction_usage()
 254    }
 255
 256    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 257        self.get_or_init_zeta_project(project, cx);
 258    }
 259
 260    pub fn register_buffer(
 261        &mut self,
 262        buffer: &Entity<Buffer>,
 263        project: &Entity<Project>,
 264        cx: &mut Context<Self>,
 265    ) {
 266        let zeta_project = self.get_or_init_zeta_project(project, cx);
 267        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 268    }
 269
 270    fn get_or_init_zeta_project(
 271        &mut self,
 272        project: &Entity<Project>,
 273        cx: &mut App,
 274    ) -> &mut ZetaProject {
 275        self.projects
 276            .entry(project.entity_id())
 277            .or_insert_with(|| ZetaProject {
 278                syntax_index: cx.new(|cx| {
 279                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 280                }),
 281                events: VecDeque::new(),
 282                registered_buffers: HashMap::new(),
 283                current_prediction: None,
 284            })
 285    }
 286
 287    fn register_buffer_impl<'a>(
 288        zeta_project: &'a mut ZetaProject,
 289        buffer: &Entity<Buffer>,
 290        project: &Entity<Project>,
 291        cx: &mut Context<Self>,
 292    ) -> &'a mut RegisteredBuffer {
 293        let buffer_id = buffer.entity_id();
 294        match zeta_project.registered_buffers.entry(buffer_id) {
 295            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 296            hash_map::Entry::Vacant(entry) => {
 297                let snapshot = buffer.read(cx).snapshot();
 298                let project_entity_id = project.entity_id();
 299                entry.insert(RegisteredBuffer {
 300                    snapshot,
 301                    _subscriptions: [
 302                        cx.subscribe(buffer, {
 303                            let project = project.downgrade();
 304                            move |this, buffer, event, cx| {
 305                                if let language::BufferEvent::Edited = event
 306                                    && let Some(project) = project.upgrade()
 307                                {
 308                                    this.report_changes_for_buffer(&buffer, &project, cx);
 309                                }
 310                            }
 311                        }),
 312                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 313                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 314                            else {
 315                                return;
 316                            };
 317                            zeta_project.registered_buffers.remove(&buffer_id);
 318                        }),
 319                    ],
 320                })
 321            }
 322        }
 323    }
 324
 325    fn report_changes_for_buffer(
 326        &mut self,
 327        buffer: &Entity<Buffer>,
 328        project: &Entity<Project>,
 329        cx: &mut Context<Self>,
 330    ) -> BufferSnapshot {
 331        let zeta_project = self.get_or_init_zeta_project(project, cx);
 332        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 333
 334        let new_snapshot = buffer.read(cx).snapshot();
 335        if new_snapshot.version != registered_buffer.snapshot.version {
 336            let old_snapshot =
 337                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 338            Self::push_event(
 339                zeta_project,
 340                Event::BufferChange {
 341                    old_snapshot,
 342                    new_snapshot: new_snapshot.clone(),
 343                    timestamp: Instant::now(),
 344                },
 345            );
 346        }
 347
 348        new_snapshot
 349    }
 350
 351    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 352        let events = &mut zeta_project.events;
 353
 354        if let Some(Event::BufferChange {
 355            new_snapshot: last_new_snapshot,
 356            timestamp: last_timestamp,
 357            ..
 358        }) = events.back_mut()
 359        {
 360            // Coalesce edits for the same buffer when they happen one after the other.
 361            let Event::BufferChange {
 362                old_snapshot,
 363                new_snapshot,
 364                timestamp,
 365            } = &event;
 366
 367            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 368                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 369                && old_snapshot.version == last_new_snapshot.version
 370            {
 371                *last_new_snapshot = new_snapshot.clone();
 372                *last_timestamp = *timestamp;
 373                return;
 374            }
 375        }
 376
 377        if events.len() >= MAX_EVENT_COUNT {
 378            // These are halved instead of popping to improve prompt caching.
 379            events.drain(..MAX_EVENT_COUNT / 2);
 380        }
 381
 382        events.push_back(event);
 383    }
 384
 385    fn current_prediction_for_buffer(
 386        &self,
 387        buffer: &Entity<Buffer>,
 388        project: &Entity<Project>,
 389        cx: &App,
 390    ) -> Option<BufferEditPrediction<'_>> {
 391        let project_state = self.projects.get(&project.entity_id())?;
 392
 393        let CurrentEditPrediction {
 394            requested_by_buffer_id,
 395            prediction,
 396        } = project_state.current_prediction.as_ref()?;
 397
 398        if prediction.targets_buffer(buffer.read(cx), cx) {
 399            Some(BufferEditPrediction::Local { prediction })
 400        } else if *requested_by_buffer_id == buffer.entity_id() {
 401            Some(BufferEditPrediction::Jump { prediction })
 402        } else {
 403            None
 404        }
 405    }
 406
 407    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 408        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 409            return;
 410        };
 411
 412        let Some(prediction) = project_state.current_prediction.take() else {
 413            return;
 414        };
 415        let request_id = prediction.prediction.id.into();
 416
 417        let client = self.client.clone();
 418        let llm_token = self.llm_token.clone();
 419        let app_version = AppVersion::global(cx);
 420        cx.spawn(async move |this, cx| {
 421            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 422                http_client::Url::parse(&predict_edits_url)?
 423            } else {
 424                client
 425                    .http_client()
 426                    .build_zed_llm_url("/predict_edits/accept", &[])?
 427            };
 428
 429            let response = cx
 430                .background_spawn(Self::send_api_request::<()>(
 431                    move |builder| {
 432                        let req = builder.uri(url.as_ref()).body(
 433                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 434                        );
 435                        Ok(req?)
 436                    },
 437                    client,
 438                    llm_token,
 439                    app_version,
 440                ))
 441                .await;
 442
 443            Self::handle_api_response(&this, response, cx)?;
 444            anyhow::Ok(())
 445        })
 446        .detach_and_log_err(cx);
 447    }
 448
 449    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 450        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 451            project_state.current_prediction.take();
 452        };
 453    }
 454
 455    pub fn refresh_prediction(
 456        &mut self,
 457        project: &Entity<Project>,
 458        buffer: &Entity<Buffer>,
 459        position: language::Anchor,
 460        cx: &mut Context<Self>,
 461    ) -> Task<Result<()>> {
 462        let request_task = self.request_prediction(project, buffer, position, cx);
 463        let buffer = buffer.clone();
 464        let project = project.clone();
 465
 466        cx.spawn(async move |this, cx| {
 467            if let Some(prediction) = request_task.await? {
 468                this.update(cx, |this, cx| {
 469                    let project_state = this
 470                        .projects
 471                        .get_mut(&project.entity_id())
 472                        .context("Project not found")?;
 473
 474                    let new_prediction = CurrentEditPrediction {
 475                        requested_by_buffer_id: buffer.entity_id(),
 476                        prediction: prediction,
 477                    };
 478
 479                    if project_state
 480                        .current_prediction
 481                        .as_ref()
 482                        .is_none_or(|old_prediction| {
 483                            new_prediction.should_replace_prediction(&old_prediction, cx)
 484                        })
 485                    {
 486                        project_state.current_prediction = Some(new_prediction);
 487                    }
 488                    anyhow::Ok(())
 489                })??;
 490            }
 491            Ok(())
 492        })
 493    }
 494
 495    fn request_prediction(
 496        &mut self,
 497        project: &Entity<Project>,
 498        buffer: &Entity<Buffer>,
 499        position: language::Anchor,
 500        cx: &mut Context<Self>,
 501    ) -> Task<Result<Option<EditPrediction>>> {
 502        let project_state = self.projects.get(&project.entity_id());
 503
 504        let index_state = project_state.map(|state| {
 505            state
 506                .syntax_index
 507                .read_with(cx, |index, _cx| index.state().clone())
 508        });
 509        let options = self.options.clone();
 510        let snapshot = buffer.read(cx).snapshot();
 511        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 512            return Task::ready(Err(anyhow!("No file path for excerpt")));
 513        };
 514        let client = self.client.clone();
 515        let llm_token = self.llm_token.clone();
 516        let app_version = AppVersion::global(cx);
 517        let worktree_snapshots = project
 518            .read(cx)
 519            .worktrees(cx)
 520            .map(|worktree| worktree.read(cx).snapshot())
 521            .collect::<Vec<_>>();
 522        let debug_tx = self.debug_tx.clone();
 523
 524        let events = project_state
 525            .map(|state| {
 526                state
 527                    .events
 528                    .iter()
 529                    .filter_map(|event| match event {
 530                        Event::BufferChange {
 531                            old_snapshot,
 532                            new_snapshot,
 533                            ..
 534                        } => {
 535                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 536
 537                            let old_path = old_snapshot.file().and_then(|f| {
 538                                let old_path = f.full_path(cx);
 539                                if Some(&old_path) != path.as_ref() {
 540                                    Some(old_path)
 541                                } else {
 542                                    None
 543                                }
 544                            });
 545
 546                            // TODO [zeta2] move to bg?
 547                            let diff =
 548                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 549
 550                            if path == old_path && diff.is_empty() {
 551                                None
 552                            } else {
 553                                Some(predict_edits_v3::Event::BufferChange {
 554                                    old_path,
 555                                    path,
 556                                    diff,
 557                                    //todo: Actually detect if this edit was predicted or not
 558                                    predicted: false,
 559                                })
 560                            }
 561                        }
 562                    })
 563                    .collect::<Vec<_>>()
 564            })
 565            .unwrap_or_default();
 566
 567        let diagnostics = snapshot.diagnostic_sets().clone();
 568
 569        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 570            let mut path = f.worktree.read(cx).absolutize(&f.path);
 571            if path.pop() { Some(path) } else { None }
 572        });
 573
 574        let request_task = cx.background_spawn({
 575            let snapshot = snapshot.clone();
 576            let buffer = buffer.clone();
 577            async move {
 578                let index_state = if let Some(index_state) = index_state {
 579                    Some(index_state.lock_owned().await)
 580                } else {
 581                    None
 582                };
 583
 584                let cursor_offset = position.to_offset(&snapshot);
 585                let cursor_point = cursor_offset.to_point(&snapshot);
 586
 587                let before_retrieval = chrono::Utc::now();
 588
 589                let Some(context) = EditPredictionContext::gather_context(
 590                    cursor_point,
 591                    &snapshot,
 592                    parent_abs_path.as_deref(),
 593                    &options.context,
 594                    index_state.as_deref(),
 595                ) else {
 596                    return Ok((None, None));
 597                };
 598
 599                let retrieval_time = chrono::Utc::now() - before_retrieval;
 600
 601                let (diagnostic_groups, diagnostic_groups_truncated) =
 602                    Self::gather_nearby_diagnostics(
 603                        cursor_offset,
 604                        &diagnostics,
 605                        &snapshot,
 606                        options.max_diagnostic_bytes,
 607                    );
 608
 609                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
 610
 611                let request = make_cloud_request(
 612                    excerpt_path,
 613                    context,
 614                    events,
 615                    // TODO data collection
 616                    false,
 617                    diagnostic_groups,
 618                    diagnostic_groups_truncated,
 619                    None,
 620                    debug_context.is_some(),
 621                    &worktree_snapshots,
 622                    index_state.as_deref(),
 623                    Some(options.max_prompt_bytes),
 624                    options.prompt_format,
 625                );
 626
 627                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
 628                    let (response_tx, response_rx) = oneshot::channel();
 629
 630                    let local_prompt = PlannedPrompt::populate(&request)
 631                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 632                        .map_err(|err| err.to_string());
 633
 634                    debug_tx
 635                        .unbounded_send(PredictionDebugInfo {
 636                            context,
 637                            retrieval_time,
 638                            buffer: buffer.downgrade(),
 639                            local_prompt,
 640                            position,
 641                            response_rx,
 642                        })
 643                        .ok();
 644                    Some(response_tx)
 645                } else {
 646                    None
 647                };
 648
 649                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 650                    if let Some(debug_response_tx) = debug_response_tx {
 651                        debug_response_tx
 652                            .send(Err("Request skipped".to_string()))
 653                            .ok();
 654                    }
 655                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 656                }
 657
 658                let response =
 659                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 660
 661                if let Some(debug_response_tx) = debug_response_tx {
 662                    debug_response_tx
 663                        .send(response.as_ref().map_err(|err| err.to_string()).and_then(
 664                            |response| match some_or_debug_panic(response.0.debug_info.clone()) {
 665                                Some(debug_info) => Ok(debug_info),
 666                                None => Err("Missing debug info".to_string()),
 667                            },
 668                        ))
 669                        .ok();
 670                }
 671
 672                response.map(|(res, usage)| (Some(res), usage))
 673            }
 674        });
 675
 676        let buffer = buffer.clone();
 677
 678        cx.spawn({
 679            let project = project.clone();
 680            async move |this, cx| {
 681                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 682                else {
 683                    return Ok(None);
 684                };
 685
 686                // TODO telemetry: duration, etc
 687                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 688            }
 689        })
 690    }
 691
 692    async fn send_prediction_request(
 693        client: Arc<Client>,
 694        llm_token: LlmApiToken,
 695        app_version: SemanticVersion,
 696        request: predict_edits_v3::PredictEditsRequest,
 697    ) -> Result<(
 698        predict_edits_v3::PredictEditsResponse,
 699        Option<EditPredictionUsage>,
 700    )> {
 701        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 702            http_client::Url::parse(&predict_edits_url)?
 703        } else {
 704            client
 705                .http_client()
 706                .build_zed_llm_url("/predict_edits/v3", &[])?
 707        };
 708
 709        Self::send_api_request(
 710            |builder| {
 711                let req = builder
 712                    .uri(url.as_ref())
 713                    .body(serde_json::to_string(&request)?.into());
 714                Ok(req?)
 715            },
 716            client,
 717            llm_token,
 718            app_version,
 719        )
 720        .await
 721    }
 722
 723    fn handle_api_response<T>(
 724        this: &WeakEntity<Self>,
 725        response: Result<(T, Option<EditPredictionUsage>)>,
 726        cx: &mut gpui::AsyncApp,
 727    ) -> Result<T> {
 728        match response {
 729            Ok((data, usage)) => {
 730                if let Some(usage) = usage {
 731                    this.update(cx, |this, cx| {
 732                        this.user_store.update(cx, |user_store, cx| {
 733                            user_store.update_edit_prediction_usage(usage, cx);
 734                        });
 735                    })
 736                    .ok();
 737                }
 738                Ok(data)
 739            }
 740            Err(err) => {
 741                if err.is::<ZedUpdateRequiredError>() {
 742                    cx.update(|cx| {
 743                        this.update(cx, |this, _cx| {
 744                            this.update_required = true;
 745                        })
 746                        .ok();
 747
 748                        let error_message: SharedString = err.to_string().into();
 749                        show_app_notification(
 750                            NotificationId::unique::<ZedUpdateRequiredError>(),
 751                            cx,
 752                            move |cx| {
 753                                cx.new(|cx| {
 754                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 755                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 756                                })
 757                            },
 758                        );
 759                    })
 760                    .ok();
 761                }
 762                Err(err)
 763            }
 764        }
 765    }
 766
 767    async fn send_api_request<Res>(
 768        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 769        client: Arc<Client>,
 770        llm_token: LlmApiToken,
 771        app_version: SemanticVersion,
 772    ) -> Result<(Res, Option<EditPredictionUsage>)>
 773    where
 774        Res: DeserializeOwned,
 775    {
 776        let http_client = client.http_client();
 777        let mut token = llm_token.acquire(&client).await?;
 778        let mut did_retry = false;
 779
 780        loop {
 781            let request_builder = http_client::Request::builder().method(Method::POST);
 782
 783            let request = build(
 784                request_builder
 785                    .header("Content-Type", "application/json")
 786                    .header("Authorization", format!("Bearer {}", token))
 787                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 788            )?;
 789
 790            let mut response = http_client.send(request).await?;
 791
 792            if let Some(minimum_required_version) = response
 793                .headers()
 794                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 795                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 796            {
 797                anyhow::ensure!(
 798                    app_version >= minimum_required_version,
 799                    ZedUpdateRequiredError {
 800                        minimum_version: minimum_required_version
 801                    }
 802                );
 803            }
 804
 805            if response.status().is_success() {
 806                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 807
 808                let mut body = Vec::new();
 809                response.body_mut().read_to_end(&mut body).await?;
 810                return Ok((serde_json::from_slice(&body)?, usage));
 811            } else if !did_retry
 812                && response
 813                    .headers()
 814                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 815                    .is_some()
 816            {
 817                did_retry = true;
 818                token = llm_token.refresh(&client).await?;
 819            } else {
 820                let mut body = String::new();
 821                response.body_mut().read_to_string(&mut body).await?;
 822                anyhow::bail!(
 823                    "Request failed with status: {:?}\nBody: {}",
 824                    response.status(),
 825                    body
 826                );
 827            }
 828        }
 829    }
 830
 831    fn gather_nearby_diagnostics(
 832        cursor_offset: usize,
 833        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 834        snapshot: &BufferSnapshot,
 835        max_diagnostics_bytes: usize,
 836    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 837        // TODO: Could make this more efficient
 838        let mut diagnostic_groups = Vec::new();
 839        for (language_server_id, diagnostics) in diagnostic_sets {
 840            let mut groups = Vec::new();
 841            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 842            diagnostic_groups.extend(
 843                groups
 844                    .into_iter()
 845                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 846            );
 847        }
 848
 849        // sort by proximity to cursor
 850        diagnostic_groups.sort_by_key(|group| {
 851            let range = &group.entries[group.primary_ix].range;
 852            if range.start >= cursor_offset {
 853                range.start - cursor_offset
 854            } else if cursor_offset >= range.end {
 855                cursor_offset - range.end
 856            } else {
 857                (cursor_offset - range.start).min(range.end - cursor_offset)
 858            }
 859        });
 860
 861        let mut results = Vec::new();
 862        let mut diagnostic_groups_truncated = false;
 863        let mut diagnostics_byte_count = 0;
 864        for group in diagnostic_groups {
 865            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 866            diagnostics_byte_count += raw_value.get().len();
 867            if diagnostics_byte_count > max_diagnostics_bytes {
 868                diagnostic_groups_truncated = true;
 869                break;
 870            }
 871            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 872        }
 873
 874        (results, diagnostic_groups_truncated)
 875    }
 876
 877    // TODO: Dedupe with similar code in request_prediction?
 878    pub fn cloud_request_for_zeta_cli(
 879        &mut self,
 880        project: &Entity<Project>,
 881        buffer: &Entity<Buffer>,
 882        position: language::Anchor,
 883        cx: &mut Context<Self>,
 884    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 885        let project_state = self.projects.get(&project.entity_id());
 886
 887        let index_state = project_state.map(|state| {
 888            state
 889                .syntax_index
 890                .read_with(cx, |index, _cx| index.state().clone())
 891        });
 892        let options = self.options.clone();
 893        let snapshot = buffer.read(cx).snapshot();
 894        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 895            return Task::ready(Err(anyhow!("No file path for excerpt")));
 896        };
 897        let worktree_snapshots = project
 898            .read(cx)
 899            .worktrees(cx)
 900            .map(|worktree| worktree.read(cx).snapshot())
 901            .collect::<Vec<_>>();
 902
 903        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 904            let mut path = f.worktree.read(cx).absolutize(&f.path);
 905            if path.pop() { Some(path) } else { None }
 906        });
 907
 908        cx.background_spawn(async move {
 909            let index_state = if let Some(index_state) = index_state {
 910                Some(index_state.lock_owned().await)
 911            } else {
 912                None
 913            };
 914
 915            let cursor_point = position.to_point(&snapshot);
 916
 917            let debug_info = true;
 918            EditPredictionContext::gather_context(
 919                cursor_point,
 920                &snapshot,
 921                parent_abs_path.as_deref(),
 922                &options.context,
 923                index_state.as_deref(),
 924            )
 925            .context("Failed to select excerpt")
 926            .map(|context| {
 927                make_cloud_request(
 928                    excerpt_path.into(),
 929                    context,
 930                    // TODO pass everything
 931                    Vec::new(),
 932                    false,
 933                    Vec::new(),
 934                    false,
 935                    None,
 936                    debug_info,
 937                    &worktree_snapshots,
 938                    index_state.as_deref(),
 939                    Some(options.max_prompt_bytes),
 940                    options.prompt_format,
 941                )
 942            })
 943        })
 944    }
 945
 946    pub fn wait_for_initial_indexing(
 947        &mut self,
 948        project: &Entity<Project>,
 949        cx: &mut App,
 950    ) -> Task<Result<()>> {
 951        let zeta_project = self.get_or_init_zeta_project(project, cx);
 952        zeta_project
 953            .syntax_index
 954            .read(cx)
 955            .wait_for_initial_file_indexing(cx)
 956    }
 957}
 958
 959#[derive(Error, Debug)]
 960#[error(
 961    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 962)]
 963pub struct ZedUpdateRequiredError {
 964    minimum_version: SemanticVersion,
 965}
 966
 967fn make_cloud_request(
 968    excerpt_path: Arc<Path>,
 969    context: EditPredictionContext,
 970    events: Vec<predict_edits_v3::Event>,
 971    can_collect_data: bool,
 972    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 973    diagnostic_groups_truncated: bool,
 974    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 975    debug_info: bool,
 976    worktrees: &Vec<worktree::Snapshot>,
 977    index_state: Option<&SyntaxIndexState>,
 978    prompt_max_bytes: Option<usize>,
 979    prompt_format: PromptFormat,
 980) -> predict_edits_v3::PredictEditsRequest {
 981    let mut signatures = Vec::new();
 982    let mut declaration_to_signature_index = HashMap::default();
 983    let mut referenced_declarations = Vec::new();
 984
 985    for snippet in context.declarations {
 986        let project_entry_id = snippet.declaration.project_entry_id();
 987        let Some(path) = worktrees.iter().find_map(|worktree| {
 988            worktree.entry_for_id(project_entry_id).map(|entry| {
 989                let mut full_path = RelPathBuf::new();
 990                full_path.push(worktree.root_name());
 991                full_path.push(&entry.path);
 992                full_path
 993            })
 994        }) else {
 995            continue;
 996        };
 997
 998        let parent_index = index_state.and_then(|index_state| {
 999            snippet.declaration.parent().and_then(|parent| {
1000                add_signature(
1001                    parent,
1002                    &mut declaration_to_signature_index,
1003                    &mut signatures,
1004                    index_state,
1005                )
1006            })
1007        });
1008
1009        let (text, text_is_truncated) = snippet.declaration.item_text();
1010        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1011            path: path.as_std_path().into(),
1012            text: text.into(),
1013            range: snippet.declaration.item_line_range(),
1014            text_is_truncated,
1015            signature_range: snippet.declaration.signature_range_in_item_text(),
1016            parent_index,
1017            signature_score: snippet.score(DeclarationStyle::Signature),
1018            declaration_score: snippet.score(DeclarationStyle::Declaration),
1019            score_components: snippet.components,
1020        });
1021    }
1022
1023    let excerpt_parent = index_state.and_then(|index_state| {
1024        context
1025            .excerpt
1026            .parent_declarations
1027            .last()
1028            .and_then(|(parent, _)| {
1029                add_signature(
1030                    *parent,
1031                    &mut declaration_to_signature_index,
1032                    &mut signatures,
1033                    index_state,
1034                )
1035            })
1036    });
1037
1038    predict_edits_v3::PredictEditsRequest {
1039        excerpt_path,
1040        excerpt: context.excerpt_text.body,
1041        excerpt_line_range: context.excerpt.line_range,
1042        excerpt_range: context.excerpt.range,
1043        cursor_point: predict_edits_v3::Point {
1044            line: predict_edits_v3::Line(context.cursor_point.row),
1045            column: context.cursor_point.column,
1046        },
1047        referenced_declarations,
1048        signatures,
1049        excerpt_parent,
1050        events,
1051        can_collect_data,
1052        diagnostic_groups,
1053        diagnostic_groups_truncated,
1054        git_info,
1055        debug_info,
1056        prompt_max_bytes,
1057        prompt_format,
1058    }
1059}
1060
1061fn add_signature(
1062    declaration_id: DeclarationId,
1063    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1064    signatures: &mut Vec<Signature>,
1065    index: &SyntaxIndexState,
1066) -> Option<usize> {
1067    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1068        return Some(*signature_index);
1069    }
1070    let Some(parent_declaration) = index.declaration(declaration_id) else {
1071        log::error!("bug: missing parent declaration");
1072        return None;
1073    };
1074    let parent_index = parent_declaration.parent().and_then(|parent| {
1075        add_signature(parent, declaration_to_signature_index, signatures, index)
1076    });
1077    let (text, text_is_truncated) = parent_declaration.signature_text();
1078    let signature_index = signatures.len();
1079    signatures.push(Signature {
1080        text: text.into(),
1081        text_is_truncated,
1082        parent_index,
1083        range: parent_declaration.signature_line_range(),
1084    });
1085    declaration_to_signature_index.insert(declaration_id, signature_index);
1086    Some(signature_index)
1087}
1088
1089#[cfg(test)]
1090mod tests {
1091    use std::{
1092        path::{Path, PathBuf},
1093        sync::Arc,
1094    };
1095
1096    use client::UserStore;
1097    use clock::FakeSystemClock;
1098    use cloud_llm_client::predict_edits_v3::{self, Point};
1099    use edit_prediction_context::Line;
1100    use futures::{
1101        AsyncReadExt, StreamExt,
1102        channel::{mpsc, oneshot},
1103    };
1104    use gpui::{
1105        Entity, TestAppContext,
1106        http_client::{FakeHttpClient, Response},
1107        prelude::*,
1108    };
1109    use indoc::indoc;
1110    use language::{LanguageServerId, OffsetRangeExt as _};
1111    use pretty_assertions::{assert_eq, assert_matches};
1112    use project::{FakeFs, Project};
1113    use serde_json::json;
1114    use settings::SettingsStore;
1115    use util::path;
1116    use uuid::Uuid;
1117
1118    use crate::{BufferEditPrediction, Zeta};
1119
1120    #[gpui::test]
1121    async fn test_current_state(cx: &mut TestAppContext) {
1122        let (zeta, mut req_rx) = init_test(cx);
1123        let fs = FakeFs::new(cx.executor());
1124        fs.insert_tree(
1125            "/root",
1126            json!({
1127                "1.txt": "Hello!\nHow\nBye",
1128                "2.txt": "Hola!\nComo\nAdios"
1129            }),
1130        )
1131        .await;
1132        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1133
1134        zeta.update(cx, |zeta, cx| {
1135            zeta.register_project(&project, cx);
1136        });
1137
1138        let buffer1 = project
1139            .update(cx, |project, cx| {
1140                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
1141                project.open_buffer(path, cx)
1142            })
1143            .await
1144            .unwrap();
1145        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1146        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1147
1148        // Prediction for current file
1149
1150        let prediction_task = zeta.update(cx, |zeta, cx| {
1151            zeta.refresh_prediction(&project, &buffer1, position, cx)
1152        });
1153        let (_request, respond_tx) = req_rx.next().await.unwrap();
1154        respond_tx
1155            .send(predict_edits_v3::PredictEditsResponse {
1156                request_id: Uuid::new_v4(),
1157                edits: vec![predict_edits_v3::Edit {
1158                    path: Path::new(path!("root/1.txt")).into(),
1159                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1160                    content: "Hello!\nHow are you?\nBye".into(),
1161                }],
1162                debug_info: None,
1163            })
1164            .unwrap();
1165        prediction_task.await.unwrap();
1166
1167        zeta.read_with(cx, |zeta, cx| {
1168            let prediction = zeta
1169                .current_prediction_for_buffer(&buffer1, &project, cx)
1170                .unwrap();
1171            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1172        });
1173
1174        // Prediction for another file
1175        let prediction_task = zeta.update(cx, |zeta, cx| {
1176            zeta.refresh_prediction(&project, &buffer1, position, cx)
1177        });
1178        let (_request, respond_tx) = req_rx.next().await.unwrap();
1179        respond_tx
1180            .send(predict_edits_v3::PredictEditsResponse {
1181                request_id: Uuid::new_v4(),
1182                edits: vec![predict_edits_v3::Edit {
1183                    path: Path::new(path!("root/2.txt")).into(),
1184                    range: Line(0)..Line(snapshot1.max_point().row + 1),
1185                    content: "Hola!\nComo estas?\nAdios".into(),
1186                }],
1187                debug_info: None,
1188            })
1189            .unwrap();
1190        prediction_task.await.unwrap();
1191        zeta.read_with(cx, |zeta, cx| {
1192            let prediction = zeta
1193                .current_prediction_for_buffer(&buffer1, &project, cx)
1194                .unwrap();
1195            assert_matches!(
1196                prediction,
1197                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1198            );
1199        });
1200
1201        let buffer2 = project
1202            .update(cx, |project, cx| {
1203                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1204                project.open_buffer(path, cx)
1205            })
1206            .await
1207            .unwrap();
1208
1209        zeta.read_with(cx, |zeta, cx| {
1210            let prediction = zeta
1211                .current_prediction_for_buffer(&buffer2, &project, cx)
1212                .unwrap();
1213            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1214        });
1215    }
1216
1217    #[gpui::test]
1218    async fn test_simple_request(cx: &mut TestAppContext) {
1219        let (zeta, mut req_rx) = init_test(cx);
1220        let fs = FakeFs::new(cx.executor());
1221        fs.insert_tree(
1222            "/root",
1223            json!({
1224                "foo.md":  "Hello!\nHow\nBye"
1225            }),
1226        )
1227        .await;
1228        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1229
1230        let buffer = project
1231            .update(cx, |project, cx| {
1232                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1233                project.open_buffer(path, cx)
1234            })
1235            .await
1236            .unwrap();
1237        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1238        let position = snapshot.anchor_before(language::Point::new(1, 3));
1239
1240        let prediction_task = zeta.update(cx, |zeta, cx| {
1241            zeta.request_prediction(&project, &buffer, position, cx)
1242        });
1243
1244        let (request, respond_tx) = req_rx.next().await.unwrap();
1245        assert_eq!(
1246            request.excerpt_path.as_ref(),
1247            Path::new(path!("root/foo.md"))
1248        );
1249        assert_eq!(
1250            request.cursor_point,
1251            Point {
1252                line: Line(1),
1253                column: 3
1254            }
1255        );
1256
1257        respond_tx
1258            .send(predict_edits_v3::PredictEditsResponse {
1259                request_id: Uuid::new_v4(),
1260                edits: vec![predict_edits_v3::Edit {
1261                    path: Path::new(path!("root/foo.md")).into(),
1262                    range: Line(0)..Line(snapshot.max_point().row + 1),
1263                    content: "Hello!\nHow are you?\nBye".into(),
1264                }],
1265                debug_info: None,
1266            })
1267            .unwrap();
1268
1269        let prediction = prediction_task.await.unwrap().unwrap();
1270
1271        assert_eq!(prediction.edits.len(), 1);
1272        assert_eq!(
1273            prediction.edits[0].0.to_point(&snapshot).start,
1274            language::Point::new(1, 3)
1275        );
1276        assert_eq!(prediction.edits[0].1, " are you?");
1277    }
1278
1279    #[gpui::test]
1280    async fn test_request_events(cx: &mut TestAppContext) {
1281        let (zeta, mut req_rx) = init_test(cx);
1282        let fs = FakeFs::new(cx.executor());
1283        fs.insert_tree(
1284            "/root",
1285            json!({
1286                "foo.md": "Hello!\n\nBye"
1287            }),
1288        )
1289        .await;
1290        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1291
1292        let buffer = project
1293            .update(cx, |project, cx| {
1294                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1295                project.open_buffer(path, cx)
1296            })
1297            .await
1298            .unwrap();
1299
1300        zeta.update(cx, |zeta, cx| {
1301            zeta.register_buffer(&buffer, &project, cx);
1302        });
1303
1304        buffer.update(cx, |buffer, cx| {
1305            buffer.edit(vec![(7..7, "How")], None, cx);
1306        });
1307
1308        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1309        let position = snapshot.anchor_before(language::Point::new(1, 3));
1310
1311        let prediction_task = zeta.update(cx, |zeta, cx| {
1312            zeta.request_prediction(&project, &buffer, position, cx)
1313        });
1314
1315        let (request, respond_tx) = req_rx.next().await.unwrap();
1316
1317        assert_eq!(request.events.len(), 1);
1318        assert_eq!(
1319            request.events[0],
1320            predict_edits_v3::Event::BufferChange {
1321                path: Some(PathBuf::from(path!("root/foo.md"))),
1322                old_path: None,
1323                diff: indoc! {"
1324                        @@ -1,3 +1,3 @@
1325                         Hello!
1326                        -
1327                        +How
1328                         Bye
1329                    "}
1330                .to_string(),
1331                predicted: false
1332            }
1333        );
1334
1335        respond_tx
1336            .send(predict_edits_v3::PredictEditsResponse {
1337                request_id: Uuid::new_v4(),
1338                edits: vec![predict_edits_v3::Edit {
1339                    path: Path::new(path!("root/foo.md")).into(),
1340                    range: Line(0)..Line(snapshot.max_point().row + 1),
1341                    content: "Hello!\nHow are you?\nBye".into(),
1342                }],
1343                debug_info: None,
1344            })
1345            .unwrap();
1346
1347        let prediction = prediction_task.await.unwrap().unwrap();
1348
1349        assert_eq!(prediction.edits.len(), 1);
1350        assert_eq!(
1351            prediction.edits[0].0.to_point(&snapshot).start,
1352            language::Point::new(1, 3)
1353        );
1354        assert_eq!(prediction.edits[0].1, " are you?");
1355    }
1356
1357    #[gpui::test]
1358    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1359        let (zeta, mut req_rx) = init_test(cx);
1360        let fs = FakeFs::new(cx.executor());
1361        fs.insert_tree(
1362            "/root",
1363            json!({
1364                "foo.md": "Hello!\nBye"
1365            }),
1366        )
1367        .await;
1368        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1369
1370        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1371        let diagnostic = lsp::Diagnostic {
1372            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1373            severity: Some(lsp::DiagnosticSeverity::ERROR),
1374            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1375            ..Default::default()
1376        };
1377
1378        project.update(cx, |project, cx| {
1379            project.lsp_store().update(cx, |lsp_store, cx| {
1380                // Create some diagnostics
1381                lsp_store
1382                    .update_diagnostics(
1383                        LanguageServerId(0),
1384                        lsp::PublishDiagnosticsParams {
1385                            uri: path_to_buffer_uri.clone(),
1386                            diagnostics: vec![diagnostic],
1387                            version: None,
1388                        },
1389                        None,
1390                        language::DiagnosticSourceKind::Pushed,
1391                        &[],
1392                        cx,
1393                    )
1394                    .unwrap();
1395            });
1396        });
1397
1398        let buffer = project
1399            .update(cx, |project, cx| {
1400                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1401                project.open_buffer(path, cx)
1402            })
1403            .await
1404            .unwrap();
1405
1406        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1407        let position = snapshot.anchor_before(language::Point::new(0, 0));
1408
1409        let _prediction_task = zeta.update(cx, |zeta, cx| {
1410            zeta.request_prediction(&project, &buffer, position, cx)
1411        });
1412
1413        let (request, _respond_tx) = req_rx.next().await.unwrap();
1414
1415        assert_eq!(request.diagnostic_groups.len(), 1);
1416        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1417            .unwrap();
1418        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1419        assert_eq!(
1420            value,
1421            json!({
1422                "entries": [{
1423                    "range": {
1424                        "start": 8,
1425                        "end": 10
1426                    },
1427                    "diagnostic": {
1428                        "source": null,
1429                        "code": null,
1430                        "code_description": null,
1431                        "severity": 1,
1432                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1433                        "markdown": null,
1434                        "group_id": 0,
1435                        "is_primary": true,
1436                        "is_disk_based": false,
1437                        "is_unnecessary": false,
1438                        "source_kind": "Pushed",
1439                        "data": null,
1440                        "underline": true
1441                    }
1442                }],
1443                "primary_ix": 0
1444            })
1445        );
1446    }
1447
1448    fn init_test(
1449        cx: &mut TestAppContext,
1450    ) -> (
1451        Entity<Zeta>,
1452        mpsc::UnboundedReceiver<(
1453            predict_edits_v3::PredictEditsRequest,
1454            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1455        )>,
1456    ) {
1457        cx.update(move |cx| {
1458            let settings_store = SettingsStore::test(cx);
1459            cx.set_global(settings_store);
1460            language::init(cx);
1461            Project::init_settings(cx);
1462
1463            let (req_tx, req_rx) = mpsc::unbounded();
1464
1465            let http_client = FakeHttpClient::create({
1466                move |req| {
1467                    let uri = req.uri().path().to_string();
1468                    let mut body = req.into_body();
1469                    let req_tx = req_tx.clone();
1470                    async move {
1471                        let resp = match uri.as_str() {
1472                            "/client/llm_tokens" => serde_json::to_string(&json!({
1473                                "token": "test"
1474                            }))
1475                            .unwrap(),
1476                            "/predict_edits/v3" => {
1477                                let mut buf = Vec::new();
1478                                body.read_to_end(&mut buf).await.ok();
1479                                let req = serde_json::from_slice(&buf).unwrap();
1480
1481                                let (res_tx, res_rx) = oneshot::channel();
1482                                req_tx.unbounded_send((req, res_tx)).unwrap();
1483                                serde_json::to_string(&res_rx.await?).unwrap()
1484                            }
1485                            _ => {
1486                                panic!("Unexpected path: {}", uri)
1487                            }
1488                        };
1489
1490                        Ok(Response::builder().body(resp.into()).unwrap())
1491                    }
1492                }
1493            });
1494
1495            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1496            client.cloud_client().set_credentials(1, "test".into());
1497
1498            language_model::init(client.clone(), cx);
1499
1500            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1501            let zeta = Zeta::global(&client, &user_store, cx);
1502
1503            (zeta, req_rx)
1504        })
1505    }
1506}