zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
  10use edit_prediction_context::{
  11    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  12    EditPredictionExcerptOptions, EditPredictionScoreOptions, SimilarSnippetOptions, SyntaxIndex,
  13    SyntaxIndexState,
  14};
  15use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  16use futures::AsyncReadExt as _;
  17use futures::channel::{mpsc, oneshot};
  18use gpui::http_client::{AsyncBody, Method};
  19use gpui::{
  20    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  21    http_client, prelude::*,
  22};
  23use language::BufferSnapshot;
  24use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language_model::{LlmApiToken, RefreshLlmTokenListener};
  26use project::Project;
  27use release_channel::AppVersion;
  28use serde::de::DeserializeOwned;
  29use std::collections::{HashMap, VecDeque, hash_map};
  30use std::path::Path;
  31use std::str::FromStr as _;
  32use std::sync::Arc;
  33use std::time::{Duration, Instant};
  34use thiserror::Error;
  35use util::rel_path::RelPathBuf;
  36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  37
  38mod prediction;
  39mod provider;
  40
  41use crate::prediction::EditPrediction;
  42pub use provider::ZetaEditPredictionProvider;
  43
  44const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  45
  46/// Maximum number of events to track.
  47const MAX_EVENT_COUNT: usize = 16;
  48
  49pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
  50    use_imports: true,
  51    max_retrieved_declarations: 0,
  52    excerpt: EditPredictionExcerptOptions {
  53        max_bytes: 512,
  54        min_bytes: 128,
  55        target_before_cursor_over_total_bytes: 0.5,
  56    },
  57    score: EditPredictionScoreOptions {
  58        omit_excerpt_overlaps: true,
  59    },
  60    similar_snippets: SimilarSnippetOptions::DEFAULT,
  61};
  62
  63pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  64    context: DEFAULT_CONTEXT_OPTIONS,
  65    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  66    max_diagnostic_bytes: 2048,
  67    prompt_format: PromptFormat::DEFAULT,
  68    file_indexing_parallelism: 1,
  69};
  70
/// Feature flag registered under the name `"zeta2"`.
///
/// NOTE(review): presumably gates this zeta2 edit-prediction implementation; the
/// checks that read the flag are not visible in this file.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // Opts out of the `enabled_for_staff` hook. Confirm the exact semantics of this
    // hook in the `feature_flags` crate.
    fn enabled_for_staff() -> bool {
        false
    }
}
  80
/// App-global holder of the shared `Zeta` entity. It is installed by `Zeta::global`
/// and read back by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  85
/// Edit-prediction state shared across all registered projects.
pub struct Zeta {
    // Used to build LLM URLs, send HTTP requests and acquire/refresh the LLM token.
    client: Arc<Client>,
    // Receives edit-prediction usage reported in successful API responses.
    user_store: Entity<UserStore>,
    // Bearer token sent with every API request.
    llm_token: LlmApiToken,
    // Subscription to `RefreshLlmTokenListener` created in `new`; stored so it lives
    // as long as `Zeta`. Each event triggers a token refresh.
    _llm_token_subscription: Subscription,
    // Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    // Set once the server reports that a newer Zed version is required.
    update_required: bool,
    // When set (via `debug_info`), every prediction request is mirrored to this channel.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
  96
/// Tunable settings for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub context: EditPredictionContextOptions,
    /// Prompt size limit passed (as `Some`) to `make_cloud_request`.
    pub max_prompt_bytes: usize,
    /// Byte budget passed to `Zeta::gather_nearby_diagnostics`.
    pub max_diagnostic_bytes: usize,
    /// Prompt format passed to `make_cloud_request`.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism used when creating a project's `SyntaxIndex`.
    pub file_indexing_parallelism: usize,
}
 105
/// Debug record for a single prediction request. One record is sent to the receiver
/// returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// The request built for the prediction endpoint.
    pub request: predict_edits_v3::PredictEditsRequest,
    /// Wall-clock time spent gathering context for the request.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// Prompt rendered locally from `request`, or the rendering error as a string.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server response, or with an error string. The error case
    /// includes the request being skipped via `ZED_ZETA2_SKIP_REQUEST`.
    pub response_rx: oneshot::Receiver<Result<predict_edits_v3::PredictEditsResponse, String>>,
}

/// Crate-local name for `predict_edits_v3::DebugInfo`.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 116
/// Per-project prediction state.
struct ZetaProject {
    // Syntax index whose state is passed to context gathering for each request.
    syntax_index: Entity<SyntaxIndex>,
    // Recent buffer-change history (capped by `MAX_EVENT_COUNT`), sent with each request.
    events: VecDeque<Event>,
    // Buffers whose edits are tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    // Prediction currently held for this project. It is replaced in
    // `refresh_prediction` and cleared on accept/discard.
    current_prediction: Option<CurrentEditPrediction>,
}
 123
/// A prediction together with the buffer whose request produced it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested from. This can differ
    /// from the buffer the prediction edits, which is the "jump" case.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 129
 130impl CurrentEditPrediction {
 131    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 132        let Some(new_edits) = self
 133            .prediction
 134            .interpolate(&self.prediction.buffer.read(cx))
 135        else {
 136            return false;
 137        };
 138
 139        if self.prediction.buffer != old_prediction.prediction.buffer {
 140            return true;
 141        }
 142
 143        let Some(old_edits) = old_prediction
 144            .prediction
 145            .interpolate(&old_prediction.prediction.buffer.read(cx))
 146        else {
 147            return true;
 148        };
 149
 150        // This reduces the occurrence of UI thrash from replacing edits
 151        //
 152        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 153        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 154            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 155            && old_edits.len() == 1
 156            && new_edits.len() == 1
 157        {
 158            let (old_range, old_text) = &old_edits[0];
 159            let (new_range, new_text) = &new_edits[0];
 160            new_range == old_range && new_text.starts_with(old_text)
 161        } else {
 162            true
 163        }
 164    }
 165}
 166
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer (`EditPrediction::targets_buffer`).
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets a different one.
    Jump { prediction: &'a EditPrediction },
}
 173
/// Tracking state for a buffer registered with a project.
struct RegisteredBuffer {
    // Snapshot as of the last recorded change. It is compared against the buffer's
    // current snapshot to detect new edits.
    snapshot: BufferSnapshot,
    // Edit subscription and release observer, stored so they live as long as the
    // registration.
    _subscriptions: [gpui::Subscription; 2],
}
 178
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed from `old_snapshot` to `new_snapshot`. Consecutive
    /// changes to the same buffer may be coalesced into one event by `Zeta::push_event`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        // When the latest (possibly coalesced) change was recorded.
        timestamp: Instant,
    },
}
 187
 188impl Zeta {
 189    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 190        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 191    }
 192
 193    pub fn global(
 194        client: &Arc<Client>,
 195        user_store: &Entity<UserStore>,
 196        cx: &mut App,
 197    ) -> Entity<Self> {
 198        cx.try_global::<ZetaGlobal>()
 199            .map(|global| global.0.clone())
 200            .unwrap_or_else(|| {
 201                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 202                cx.set_global(ZetaGlobal(zeta.clone()));
 203                zeta
 204            })
 205    }
 206
 207    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 208        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 209
 210        Self {
 211            projects: HashMap::new(),
 212            client,
 213            user_store,
 214            options: DEFAULT_OPTIONS,
 215            llm_token: LlmApiToken::default(),
 216            _llm_token_subscription: cx.subscribe(
 217                &refresh_llm_token_listener,
 218                |this, _listener, _event, cx| {
 219                    let client = this.client.clone();
 220                    let llm_token = this.llm_token.clone();
 221                    cx.spawn(async move |_this, _cx| {
 222                        llm_token.refresh(&client).await?;
 223                        anyhow::Ok(())
 224                    })
 225                    .detach_and_log_err(cx);
 226                },
 227            ),
 228            update_required: false,
 229            debug_tx: None,
 230        }
 231    }
 232
 233    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 234        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 235        self.debug_tx = Some(debug_watch_tx);
 236        debug_watch_rx
 237    }
 238
 239    pub fn options(&self) -> &ZetaOptions {
 240        &self.options
 241    }
 242
 243    pub fn set_options(&mut self, options: ZetaOptions) {
 244        self.options = options;
 245    }
 246
 247    pub fn clear_history(&mut self) {
 248        for zeta_project in self.projects.values_mut() {
 249            zeta_project.events.clear();
 250        }
 251    }
 252
 253    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 254        self.user_store.read(cx).edit_prediction_usage()
 255    }
 256
 257    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 258        self.get_or_init_zeta_project(project, cx);
 259    }
 260
 261    pub fn register_buffer(
 262        &mut self,
 263        buffer: &Entity<Buffer>,
 264        project: &Entity<Project>,
 265        cx: &mut Context<Self>,
 266    ) {
 267        let zeta_project = self.get_or_init_zeta_project(project, cx);
 268        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 269    }
 270
 271    fn get_or_init_zeta_project(
 272        &mut self,
 273        project: &Entity<Project>,
 274        cx: &mut App,
 275    ) -> &mut ZetaProject {
 276        self.projects
 277            .entry(project.entity_id())
 278            .or_insert_with(|| ZetaProject {
 279                syntax_index: cx.new(|cx| {
 280                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 281                }),
 282                events: VecDeque::new(),
 283                registered_buffers: HashMap::new(),
 284                current_prediction: None,
 285            })
 286    }
 287
 288    fn register_buffer_impl<'a>(
 289        zeta_project: &'a mut ZetaProject,
 290        buffer: &Entity<Buffer>,
 291        project: &Entity<Project>,
 292        cx: &mut Context<Self>,
 293    ) -> &'a mut RegisteredBuffer {
 294        let buffer_id = buffer.entity_id();
 295        match zeta_project.registered_buffers.entry(buffer_id) {
 296            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 297            hash_map::Entry::Vacant(entry) => {
 298                let snapshot = buffer.read(cx).snapshot();
 299                let project_entity_id = project.entity_id();
 300                entry.insert(RegisteredBuffer {
 301                    snapshot,
 302                    _subscriptions: [
 303                        cx.subscribe(buffer, {
 304                            let project = project.downgrade();
 305                            move |this, buffer, event, cx| {
 306                                if let language::BufferEvent::Edited = event
 307                                    && let Some(project) = project.upgrade()
 308                                {
 309                                    this.report_changes_for_buffer(&buffer, &project, cx);
 310                                }
 311                            }
 312                        }),
 313                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 314                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 315                            else {
 316                                return;
 317                            };
 318                            zeta_project.registered_buffers.remove(&buffer_id);
 319                        }),
 320                    ],
 321                })
 322            }
 323        }
 324    }
 325
 326    fn report_changes_for_buffer(
 327        &mut self,
 328        buffer: &Entity<Buffer>,
 329        project: &Entity<Project>,
 330        cx: &mut Context<Self>,
 331    ) -> BufferSnapshot {
 332        let zeta_project = self.get_or_init_zeta_project(project, cx);
 333        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 334
 335        let new_snapshot = buffer.read(cx).snapshot();
 336        if new_snapshot.version != registered_buffer.snapshot.version {
 337            let old_snapshot =
 338                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 339            Self::push_event(
 340                zeta_project,
 341                Event::BufferChange {
 342                    old_snapshot,
 343                    new_snapshot: new_snapshot.clone(),
 344                    timestamp: Instant::now(),
 345                },
 346            );
 347        }
 348
 349        new_snapshot
 350    }
 351
 352    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 353        let events = &mut zeta_project.events;
 354
 355        if let Some(Event::BufferChange {
 356            new_snapshot: last_new_snapshot,
 357            timestamp: last_timestamp,
 358            ..
 359        }) = events.back_mut()
 360        {
 361            // Coalesce edits for the same buffer when they happen one after the other.
 362            let Event::BufferChange {
 363                old_snapshot,
 364                new_snapshot,
 365                timestamp,
 366            } = &event;
 367
 368            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 369                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 370                && old_snapshot.version == last_new_snapshot.version
 371            {
 372                *last_new_snapshot = new_snapshot.clone();
 373                *last_timestamp = *timestamp;
 374                return;
 375            }
 376        }
 377
 378        if events.len() >= MAX_EVENT_COUNT {
 379            // These are halved instead of popping to improve prompt caching.
 380            events.drain(..MAX_EVENT_COUNT / 2);
 381        }
 382
 383        events.push_back(event);
 384    }
 385
 386    fn current_prediction_for_buffer(
 387        &self,
 388        buffer: &Entity<Buffer>,
 389        project: &Entity<Project>,
 390        cx: &App,
 391    ) -> Option<BufferEditPrediction<'_>> {
 392        let project_state = self.projects.get(&project.entity_id())?;
 393
 394        let CurrentEditPrediction {
 395            requested_by_buffer_id,
 396            prediction,
 397        } = project_state.current_prediction.as_ref()?;
 398
 399        if prediction.targets_buffer(buffer.read(cx), cx) {
 400            Some(BufferEditPrediction::Local { prediction })
 401        } else if *requested_by_buffer_id == buffer.entity_id() {
 402            Some(BufferEditPrediction::Jump { prediction })
 403        } else {
 404            None
 405        }
 406    }
 407
 408    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 409        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 410            return;
 411        };
 412
 413        let Some(prediction) = project_state.current_prediction.take() else {
 414            return;
 415        };
 416        let request_id = prediction.prediction.id.into();
 417
 418        let client = self.client.clone();
 419        let llm_token = self.llm_token.clone();
 420        let app_version = AppVersion::global(cx);
 421        cx.spawn(async move |this, cx| {
 422            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 423                http_client::Url::parse(&predict_edits_url)?
 424            } else {
 425                client
 426                    .http_client()
 427                    .build_zed_llm_url("/predict_edits/accept", &[])?
 428            };
 429
 430            let response = cx
 431                .background_spawn(Self::send_api_request::<()>(
 432                    move |builder| {
 433                        let req = builder.uri(url.as_ref()).body(
 434                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 435                        );
 436                        Ok(req?)
 437                    },
 438                    client,
 439                    llm_token,
 440                    app_version,
 441                ))
 442                .await;
 443
 444            Self::handle_api_response(&this, response, cx)?;
 445            anyhow::Ok(())
 446        })
 447        .detach_and_log_err(cx);
 448    }
 449
 450    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 451        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 452            project_state.current_prediction.take();
 453        };
 454    }
 455
 456    pub fn refresh_prediction(
 457        &mut self,
 458        project: &Entity<Project>,
 459        buffer: &Entity<Buffer>,
 460        position: language::Anchor,
 461        cx: &mut Context<Self>,
 462    ) -> Task<Result<()>> {
 463        let request_task = self.request_prediction(project, buffer, position, cx);
 464        let buffer = buffer.clone();
 465        let project = project.clone();
 466
 467        cx.spawn(async move |this, cx| {
 468            if let Some(prediction) = request_task.await? {
 469                this.update(cx, |this, cx| {
 470                    let project_state = this
 471                        .projects
 472                        .get_mut(&project.entity_id())
 473                        .context("Project not found")?;
 474
 475                    let new_prediction = CurrentEditPrediction {
 476                        requested_by_buffer_id: buffer.entity_id(),
 477                        prediction: prediction,
 478                    };
 479
 480                    if project_state
 481                        .current_prediction
 482                        .as_ref()
 483                        .is_none_or(|old_prediction| {
 484                            new_prediction.should_replace_prediction(&old_prediction, cx)
 485                        })
 486                    {
 487                        project_state.current_prediction = Some(new_prediction);
 488                    }
 489                    anyhow::Ok(())
 490                })??;
 491            }
 492            Ok(())
 493        })
 494    }
 495
 496    fn request_prediction(
 497        &mut self,
 498        project: &Entity<Project>,
 499        buffer: &Entity<Buffer>,
 500        position: language::Anchor,
 501        cx: &mut Context<Self>,
 502    ) -> Task<Result<Option<EditPrediction>>> {
 503        let project_state = self.projects.get(&project.entity_id());
 504
 505        let index_state = project_state.map(|state| {
 506            state
 507                .syntax_index
 508                .read_with(cx, |index, _cx| index.state().clone())
 509        });
 510        let options = self.options.clone();
 511        let snapshot = buffer.read(cx).snapshot();
 512        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 513            return Task::ready(Err(anyhow!("No file path for excerpt")));
 514        };
 515        let client = self.client.clone();
 516        let llm_token = self.llm_token.clone();
 517        let app_version = AppVersion::global(cx);
 518        let worktree_snapshots = project
 519            .read(cx)
 520            .worktrees(cx)
 521            .map(|worktree| worktree.read(cx).snapshot())
 522            .collect::<Vec<_>>();
 523        let debug_tx = self.debug_tx.clone();
 524
 525        let events = project_state
 526            .map(|state| {
 527                state
 528                    .events
 529                    .iter()
 530                    .filter_map(|event| match event {
 531                        Event::BufferChange {
 532                            old_snapshot,
 533                            new_snapshot,
 534                            ..
 535                        } => {
 536                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 537
 538                            let old_path = old_snapshot.file().and_then(|f| {
 539                                let old_path = f.full_path(cx);
 540                                if Some(&old_path) != path.as_ref() {
 541                                    Some(old_path)
 542                                } else {
 543                                    None
 544                                }
 545                            });
 546
 547                            // TODO [zeta2] move to bg?
 548                            let diff =
 549                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 550
 551                            if path == old_path && diff.is_empty() {
 552                                None
 553                            } else {
 554                                Some(predict_edits_v3::Event::BufferChange {
 555                                    old_path,
 556                                    path,
 557                                    diff,
 558                                    //todo: Actually detect if this edit was predicted or not
 559                                    predicted: false,
 560                                })
 561                            }
 562                        }
 563                    })
 564                    .collect::<Vec<_>>()
 565            })
 566            .unwrap_or_default();
 567
 568        let diagnostics = snapshot.diagnostic_sets().clone();
 569
 570        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 571            let mut path = f.worktree.read(cx).absolutize(&f.path);
 572            if path.pop() { Some(path) } else { None }
 573        });
 574
 575        // TODO data collection
 576        let can_collect_data = cx.is_staff();
 577
 578        let request_task = cx.background_spawn({
 579            let snapshot = snapshot.clone();
 580            let buffer = buffer.clone();
 581            async move {
 582                let index_state = if let Some(index_state) = index_state {
 583                    Some(index_state.lock_owned().await)
 584                } else {
 585                    None
 586                };
 587
 588                let cursor_offset = position.to_offset(&snapshot);
 589                let cursor_point = cursor_offset.to_point(&snapshot);
 590
 591                let before_retrieval = chrono::Utc::now();
 592
 593                let Some(context) = EditPredictionContext::gather_context(
 594                    cursor_point,
 595                    &snapshot,
 596                    parent_abs_path.as_deref(),
 597                    &options.context,
 598                    index_state.as_deref(),
 599                ) else {
 600                    return Ok((None, None));
 601                };
 602
 603                let retrieval_time = chrono::Utc::now() - before_retrieval;
 604
 605                let (diagnostic_groups, diagnostic_groups_truncated) =
 606                    Self::gather_nearby_diagnostics(
 607                        cursor_offset,
 608                        &diagnostics,
 609                        &snapshot,
 610                        options.max_diagnostic_bytes,
 611                    );
 612
 613                let request = make_cloud_request(
 614                    excerpt_path,
 615                    context,
 616                    events,
 617                    can_collect_data,
 618                    diagnostic_groups,
 619                    diagnostic_groups_truncated,
 620                    None,
 621                    debug_tx.is_some(),
 622                    &worktree_snapshots,
 623                    index_state.as_deref(),
 624                    Some(options.max_prompt_bytes),
 625                    options.prompt_format,
 626                );
 627
 628                let debug_response_tx = if let Some(debug_tx) = &debug_tx {
 629                    let (response_tx, response_rx) = oneshot::channel();
 630
 631                    let local_prompt = PlannedPrompt::populate(&request)
 632                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 633                        .map_err(|err| err.to_string());
 634
 635                    debug_tx
 636                        .unbounded_send(PredictionDebugInfo {
 637                            request: request.clone(),
 638                            retrieval_time,
 639                            buffer: buffer.downgrade(),
 640                            local_prompt,
 641                            position,
 642                            response_rx,
 643                        })
 644                        .ok();
 645                    Some(response_tx)
 646                } else {
 647                    None
 648                };
 649
 650                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 651                    if let Some(debug_response_tx) = debug_response_tx {
 652                        debug_response_tx
 653                            .send(Err("Request skipped".to_string()))
 654                            .ok();
 655                    }
 656                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 657                }
 658
 659                let response =
 660                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 661
 662                if let Some(debug_response_tx) = debug_response_tx {
 663                    debug_response_tx
 664                        .send(
 665                            response
 666                                .as_ref()
 667                                .map_err(|err| err.to_string())
 668                                .map(|response| response.0.clone()),
 669                        )
 670                        .ok();
 671                }
 672
 673                response.map(|(res, usage)| (Some(res), usage))
 674            }
 675        });
 676
 677        let buffer = buffer.clone();
 678
 679        cx.spawn({
 680            let project = project.clone();
 681            async move |this, cx| {
 682                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 683                else {
 684                    return Ok(None);
 685                };
 686
 687                // TODO telemetry: duration, etc
 688                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 689            }
 690        })
 691    }
 692
 693    async fn send_prediction_request(
 694        client: Arc<Client>,
 695        llm_token: LlmApiToken,
 696        app_version: SemanticVersion,
 697        request: predict_edits_v3::PredictEditsRequest,
 698    ) -> Result<(
 699        predict_edits_v3::PredictEditsResponse,
 700        Option<EditPredictionUsage>,
 701    )> {
 702        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 703            http_client::Url::parse(&predict_edits_url)?
 704        } else {
 705            client
 706                .http_client()
 707                .build_zed_llm_url("/predict_edits/v3", &[])?
 708        };
 709
 710        Self::send_api_request(
 711            |builder| {
 712                let req = builder
 713                    .uri(url.as_ref())
 714                    .body(serde_json::to_string(&request)?.into());
 715                Ok(req?)
 716            },
 717            client,
 718            llm_token,
 719            app_version,
 720        )
 721        .await
 722    }
 723
 724    fn handle_api_response<T>(
 725        this: &WeakEntity<Self>,
 726        response: Result<(T, Option<EditPredictionUsage>)>,
 727        cx: &mut gpui::AsyncApp,
 728    ) -> Result<T> {
 729        match response {
 730            Ok((data, usage)) => {
 731                if let Some(usage) = usage {
 732                    this.update(cx, |this, cx| {
 733                        this.user_store.update(cx, |user_store, cx| {
 734                            user_store.update_edit_prediction_usage(usage, cx);
 735                        });
 736                    })
 737                    .ok();
 738                }
 739                Ok(data)
 740            }
 741            Err(err) => {
 742                if err.is::<ZedUpdateRequiredError>() {
 743                    cx.update(|cx| {
 744                        this.update(cx, |this, _cx| {
 745                            this.update_required = true;
 746                        })
 747                        .ok();
 748
 749                        let error_message: SharedString = err.to_string().into();
 750                        show_app_notification(
 751                            NotificationId::unique::<ZedUpdateRequiredError>(),
 752                            cx,
 753                            move |cx| {
 754                                cx.new(|cx| {
 755                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 756                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 757                                })
 758                            },
 759                        );
 760                    })
 761                    .ok();
 762                }
 763                Err(err)
 764            }
 765        }
 766    }
 767
 768    async fn send_api_request<Res>(
 769        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 770        client: Arc<Client>,
 771        llm_token: LlmApiToken,
 772        app_version: SemanticVersion,
 773    ) -> Result<(Res, Option<EditPredictionUsage>)>
 774    where
 775        Res: DeserializeOwned,
 776    {
 777        let http_client = client.http_client();
 778        let mut token = llm_token.acquire(&client).await?;
 779        let mut did_retry = false;
 780
 781        loop {
 782            let request_builder = http_client::Request::builder().method(Method::POST);
 783
 784            let request = build(
 785                request_builder
 786                    .header("Content-Type", "application/json")
 787                    .header("Authorization", format!("Bearer {}", token))
 788                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 789            )?;
 790
 791            let mut response = http_client.send(request).await?;
 792
 793            if let Some(minimum_required_version) = response
 794                .headers()
 795                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 796                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 797            {
 798                anyhow::ensure!(
 799                    app_version >= minimum_required_version,
 800                    ZedUpdateRequiredError {
 801                        minimum_version: minimum_required_version
 802                    }
 803                );
 804            }
 805
 806            if response.status().is_success() {
 807                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 808
 809                let mut body = Vec::new();
 810                response.body_mut().read_to_end(&mut body).await?;
 811                return Ok((serde_json::from_slice(&body)?, usage));
 812            } else if !did_retry
 813                && response
 814                    .headers()
 815                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 816                    .is_some()
 817            {
 818                did_retry = true;
 819                token = llm_token.refresh(&client).await?;
 820            } else {
 821                let mut body = String::new();
 822                response.body_mut().read_to_string(&mut body).await?;
 823                anyhow::bail!(
 824                    "Request failed with status: {:?}\nBody: {}",
 825                    response.status(),
 826                    body
 827                );
 828            }
 829        }
 830    }
 831
 832    fn gather_nearby_diagnostics(
 833        cursor_offset: usize,
 834        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 835        snapshot: &BufferSnapshot,
 836        max_diagnostics_bytes: usize,
 837    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 838        // TODO: Could make this more efficient
 839        let mut diagnostic_groups = Vec::new();
 840        for (language_server_id, diagnostics) in diagnostic_sets {
 841            let mut groups = Vec::new();
 842            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 843            diagnostic_groups.extend(
 844                groups
 845                    .into_iter()
 846                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 847            );
 848        }
 849
 850        // sort by proximity to cursor
 851        diagnostic_groups.sort_by_key(|group| {
 852            let range = &group.entries[group.primary_ix].range;
 853            if range.start >= cursor_offset {
 854                range.start - cursor_offset
 855            } else if cursor_offset >= range.end {
 856                cursor_offset - range.end
 857            } else {
 858                (cursor_offset - range.start).min(range.end - cursor_offset)
 859            }
 860        });
 861
 862        let mut results = Vec::new();
 863        let mut diagnostic_groups_truncated = false;
 864        let mut diagnostics_byte_count = 0;
 865        for group in diagnostic_groups {
 866            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 867            diagnostics_byte_count += raw_value.get().len();
 868            if diagnostics_byte_count > max_diagnostics_bytes {
 869                diagnostic_groups_truncated = true;
 870                break;
 871            }
 872            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 873        }
 874
 875        (results, diagnostic_groups_truncated)
 876    }
 877
 878    // TODO: Dedupe with similar code in request_prediction?
 879    pub fn cloud_request_for_zeta_cli(
 880        &mut self,
 881        project: &Entity<Project>,
 882        buffer: &Entity<Buffer>,
 883        position: language::Anchor,
 884        cx: &mut Context<Self>,
 885    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 886        let project_state = self.projects.get(&project.entity_id());
 887
 888        let index_state = project_state.map(|state| {
 889            state
 890                .syntax_index
 891                .read_with(cx, |index, _cx| index.state().clone())
 892        });
 893        let options = self.options.clone();
 894        let snapshot = buffer.read(cx).snapshot();
 895        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 896            return Task::ready(Err(anyhow!("No file path for excerpt")));
 897        };
 898        let worktree_snapshots = project
 899            .read(cx)
 900            .worktrees(cx)
 901            .map(|worktree| worktree.read(cx).snapshot())
 902            .collect::<Vec<_>>();
 903
 904        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 905            let mut path = f.worktree.read(cx).absolutize(&f.path);
 906            if path.pop() { Some(path) } else { None }
 907        });
 908
 909        cx.background_spawn(async move {
 910            let index_state = if let Some(index_state) = index_state {
 911                Some(index_state.lock_owned().await)
 912            } else {
 913                None
 914            };
 915
 916            let cursor_point = position.to_point(&snapshot);
 917
 918            let debug_info = true;
 919            EditPredictionContext::gather_context(
 920                cursor_point,
 921                &snapshot,
 922                parent_abs_path.as_deref(),
 923                &options.context,
 924                index_state.as_deref(),
 925            )
 926            .context("Failed to select excerpt")
 927            .map(|context| {
 928                make_cloud_request(
 929                    excerpt_path.into(),
 930                    context,
 931                    // TODO pass everything
 932                    Vec::new(),
 933                    false,
 934                    Vec::new(),
 935                    false,
 936                    None,
 937                    debug_info,
 938                    &worktree_snapshots,
 939                    index_state.as_deref(),
 940                    Some(options.max_prompt_bytes),
 941                    options.prompt_format,
 942                )
 943            })
 944        })
 945    }
 946
 947    pub fn wait_for_initial_indexing(
 948        &mut self,
 949        project: &Entity<Project>,
 950        cx: &mut App,
 951    ) -> Task<Result<()>> {
 952        let zeta_project = self.get_or_init_zeta_project(project, cx);
 953        zeta_project
 954            .syntax_index
 955            .read(cx)
 956            .wait_for_initial_file_indexing(cx)
 957    }
 958}
 959
 960#[derive(Error, Debug)]
 961#[error(
 962    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 963)]
 964pub struct ZedUpdateRequiredError {
 965    minimum_version: SemanticVersion,
 966}
 967
 968fn make_cloud_request(
 969    excerpt_path: Arc<Path>,
 970    context: EditPredictionContext,
 971    events: Vec<predict_edits_v3::Event>,
 972    can_collect_data: bool,
 973    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 974    diagnostic_groups_truncated: bool,
 975    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 976    debug_info: bool,
 977    worktrees: &Vec<worktree::Snapshot>,
 978    index_state: Option<&SyntaxIndexState>,
 979    prompt_max_bytes: Option<usize>,
 980    prompt_format: PromptFormat,
 981) -> predict_edits_v3::PredictEditsRequest {
 982    let mut signatures = Vec::new();
 983    let mut declaration_to_signature_index = HashMap::default();
 984    let mut referenced_declarations = Vec::new();
 985
 986    for snippet in context.declarations {
 987        let project_entry_id = snippet.declaration.project_entry_id();
 988        let Some(path) = worktrees.iter().find_map(|worktree| {
 989            worktree.entry_for_id(project_entry_id).map(|entry| {
 990                let mut full_path = RelPathBuf::new();
 991                full_path.push(worktree.root_name());
 992                full_path.push(&entry.path);
 993                full_path
 994            })
 995        }) else {
 996            continue;
 997        };
 998
 999        let parent_index = index_state.and_then(|index_state| {
1000            snippet.declaration.parent().and_then(|parent| {
1001                add_signature(
1002                    parent,
1003                    &mut declaration_to_signature_index,
1004                    &mut signatures,
1005                    index_state,
1006                )
1007            })
1008        });
1009
1010        let (text, text_is_truncated) = snippet.declaration.item_text();
1011        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1012            path: path.as_std_path().into(),
1013            text: text.into(),
1014            range: snippet.declaration.item_line_range(),
1015            text_is_truncated,
1016            signature_range: snippet.declaration.signature_range_in_item_text(),
1017            parent_index,
1018            signature_score: snippet.score(DeclarationStyle::Signature),
1019            declaration_score: snippet.score(DeclarationStyle::Declaration),
1020            score_components: snippet.components,
1021        });
1022    }
1023
1024    let excerpt_parent = index_state.and_then(|index_state| {
1025        context
1026            .excerpt
1027            .parent_declarations
1028            .last()
1029            .and_then(|(parent, _)| {
1030                add_signature(
1031                    *parent,
1032                    &mut declaration_to_signature_index,
1033                    &mut signatures,
1034                    index_state,
1035                )
1036            })
1037    });
1038
1039    predict_edits_v3::PredictEditsRequest {
1040        excerpt_path,
1041        excerpt: context.excerpt_text.body,
1042        excerpt_line_range: context.excerpt.line_range,
1043        excerpt_range: context.excerpt.range,
1044        cursor_point: predict_edits_v3::Point {
1045            line: predict_edits_v3::Line(context.cursor_point.row),
1046            column: context.cursor_point.column,
1047        },
1048        referenced_declarations,
1049        signatures,
1050        excerpt_parent,
1051        events,
1052        can_collect_data,
1053        diagnostic_groups,
1054        diagnostic_groups_truncated,
1055        git_info,
1056        debug_info,
1057        prompt_max_bytes,
1058        prompt_format,
1059    }
1060}
1061
1062fn add_signature(
1063    declaration_id: DeclarationId,
1064    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1065    signatures: &mut Vec<Signature>,
1066    index: &SyntaxIndexState,
1067) -> Option<usize> {
1068    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1069        return Some(*signature_index);
1070    }
1071    let Some(parent_declaration) = index.declaration(declaration_id) else {
1072        log::error!("bug: missing parent declaration");
1073        return None;
1074    };
1075    let parent_index = parent_declaration.parent().and_then(|parent| {
1076        add_signature(parent, declaration_to_signature_index, signatures, index)
1077    });
1078    let (text, text_is_truncated) = parent_declaration.signature_text();
1079    let signature_index = signatures.len();
1080    signatures.push(Signature {
1081        text: text.into(),
1082        text_is_truncated,
1083        parent_index,
1084        range: parent_declaration.signature_line_range(),
1085    });
1086    declaration_to_signature_index.insert(declaration_id, signature_index);
1087    Some(signature_index)
1088}
1089
// Tests for Zeta's prediction requests. `init_test` installs a fake HTTP client
// that forwards each `/predict_edits/v3` request to the test over a channel, so
// each test can inspect the request and choose the response.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // This test covers three cases:
    // - A prediction that edits the requesting buffer is reported as `Local`.
    // - A prediction for a different file is reported to the requesting buffer
    //   as a `Jump` to that file.
    // - Once that other file is opened, the same prediction is `Local` for its
    //   buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        // (Both files have three lines, so `snapshot1`'s line count also spans
        // all of 2.txt.)
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Opening the target file makes the prediction local to its buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point. The response replaces
    // the whole file, and that replacement is expected to reduce to a single
    // insertion of " are you?" at the cursor (1, 3).
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // An edit made to a registered buffer must appear in the next request as
    // exactly one `BufferChange` event carrying a unified diff of that edit.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the empty second line ("Hello!\n" is 7 bytes).
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Diagnostics published through the LSP store must be included in the
    // request as serialized diagnostic groups. The test never sends a
    // response; it only inspects the request.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        // Line 1 ("Bye") starts at offset 7, so column 1 is offset 8. The LSP
        // end column 5 is past the end of the 3-character line, and the
        // expected end offset 10 is the end of that line.
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Initializes settings, language/project settings, a client backed by a fake
    // HTTP server, and the global `Zeta` entity.
    //
    // The fake server responds to two paths:
    // - `/client/llm_tokens` returns a fixed token.
    // - `/predict_edits/v3` deserializes the request body and sends it to the
    //   returned receiver with a oneshot sender. The test uses that sender to
    //   supply the response.
    // Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                // If the test drops `res_tx` without responding,
                                // this `?` makes the fake handler return an error.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}