zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
  10use edit_prediction_context::{
  11    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  12    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  13};
  14use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
  15use futures::channel::{mpsc, oneshot};
  16use futures::future::{BoxFuture, Shared};
  17use futures::{AsyncReadExt as _, FutureExt};
  18use gpui::http_client::{AsyncBody, Method};
  19use gpui::{
  20    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  21    http_client, prelude::*,
  22};
  23use language::BufferSnapshot;
  24use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  25use language_model::{LlmApiToken, RefreshLlmTokenListener};
  26use project::Project;
  27use release_channel::AppVersion;
  28use serde::de::DeserializeOwned;
  29use std::collections::{HashMap, VecDeque, hash_map};
  30use std::path::Path;
  31use std::str::FromStr as _;
  32use std::sync::Arc;
  33use std::time::{Duration, Instant};
  34use thiserror::Error;
  35use util::rel_path::RelPathBuf;
  36use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  37
  38mod prediction;
  39mod provider;
  40
  41use crate::prediction::EditPrediction;
  42pub use provider::ZetaEditPredictionProvider;
  43
/// Consecutive edits to the same buffer that arrive within this interval of each
/// other are coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track per project.
///
/// When the limit is reached, the oldest half of the history is dropped at once.
const MAX_EVENT_COUNT: usize = 16;

/// Default options for gathering edit prediction context around the cursor.
pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
    use_imports: true,
    // NOTE(review): 0 presumably disables declaration retrieval — confirm in
    // edit_prediction_context.
    max_retrieved_declarations: 0,
    excerpt: EditPredictionExcerptOptions {
        max_bytes: 512,
        min_bytes: 128,
        target_before_cursor_over_total_bytes: 0.5,
    },
    score: EditPredictionScoreOptions {
        omit_excerpt_overlaps: true,
    },
};

/// Options a freshly constructed [`Zeta`] starts with; replaceable at runtime via
/// `Zeta::set_options`.
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    context: DEFAULT_CONTEXT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  69
/// Feature flag named `zeta2`.
pub struct Zeta2FeatureFlag;

impl FeatureFlag for Zeta2FeatureFlag {
    const NAME: &'static str = "zeta2";

    // NOTE(review): returning `false` presumably means staff are not opted in
    // automatically — confirm against `FeatureFlag::enabled_for_staff`.
    fn enabled_for_staff() -> bool {
        false
    }
}
  79
/// App-global holder for the [`Zeta`] entity; installed by `Zeta::global` and read
/// by `Zeta::try_global`.
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  84
/// Entry point for zeta2 edit predictions.
///
/// Tracks per-project state (syntax index, recent buffer-change events, the current
/// prediction) and sends prediction requests to the Zed LLM endpoints.
pub struct Zeta {
    client: Arc<Client>,
    /// Used to read and update the edit prediction usage reported by the server.
    user_store: Entity<UserStore>,
    /// Sent as the bearer credential on API requests. It is refreshed when the
    /// global `RefreshLlmTokenListener` fires or when the server reports that the
    /// token expired.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that this Zed version is below its minimum.
    update_required: bool,
    /// Optional observer that receives a copy of every prediction request.
    observe_tx: Option<mpsc::UnboundedSender<ObservedPredictionRequest>>,
}
  95
/// Tunable settings for context gathering and prompt construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options for gathering the excerpt and related context around the cursor.
    pub context: EditPredictionContextOptions,
    /// Upper bound on prompt size, forwarded to `make_cloud_request`.
    pub max_prompt_bytes: usize,
    /// Byte budget for the diagnostics gathered near the cursor.
    pub max_diagnostic_bytes: usize,
    /// Prompt format passed to `make_cloud_request`.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's index is created.
    pub file_indexing_parallelism: usize,
}
 104
/// A copy of an outgoing prediction request, delivered to the observer installed
/// with `Zeta::observe_requests`.
#[derive(Clone)]
pub struct ObservedPredictionRequest {
    /// The request that is sent to the server (unless the request is skipped).
    pub full_request: predict_edits_v3::PredictEditsRequest,
    /// Time spent gathering context for the request.
    pub retrieval_time: TimeDelta,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
    /// The prompt rendered locally from `full_request`, or the rendering error.
    pub local_prompt: Result<String, String>,
    /// Resolves to the server's response, or to an error message if the request
    /// failed, was skipped, or was canceled.
    pub response:
        Shared<BoxFuture<'static, Result<predict_edits_v3::PredictEditsResponse, String>>>,
}

/// Alias for the v3 prediction API's `DebugInfo` type.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 117
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index used when gathering context for predictions.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first, capped at `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered in this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}

/// A project's current prediction, together with the buffer that requested it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer whose request produced this prediction. The
    /// prediction itself may target a different buffer.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 130
 131impl CurrentEditPrediction {
 132    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 133        let Some(new_edits) = self
 134            .prediction
 135            .interpolate(&self.prediction.buffer.read(cx))
 136        else {
 137            return false;
 138        };
 139
 140        if self.prediction.buffer != old_prediction.prediction.buffer {
 141            return true;
 142        }
 143
 144        let Some(old_edits) = old_prediction
 145            .prediction
 146            .interpolate(&old_prediction.prediction.buffer.read(cx))
 147        else {
 148            return true;
 149        };
 150
 151        // This reduces the occurrence of UI thrash from replacing edits
 152        //
 153        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 154        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 155            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 156            && old_edits.len() == 1
 157            && new_edits.len() == 1
 158        {
 159            let (old_range, old_text) = &old_edits[0];
 160            let (new_range, new_text) = &new_edits[0];
 161            new_range == old_range && new_text.starts_with(old_text)
 162        } else {
 163            true
 164        }
 165    }
 166}
 167
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but does not target it.
    Jump { prediction: &'a EditPrediction },
}

/// Bookkeeping for a buffer whose edits are being tracked.
struct RegisteredBuffer {
    /// The most recently recorded snapshot; new edits are compared against it.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's edit events and to its release.
    _subscriptions: [gpui::Subscription; 2],
}

/// An entry in a project's change history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes
    /// to the same buffer may be coalesced into one event, in which case
    /// `timestamp` is the time of the latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 188
 189impl Zeta {
 190    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 191        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 192    }
 193
 194    pub fn global(
 195        client: &Arc<Client>,
 196        user_store: &Entity<UserStore>,
 197        cx: &mut App,
 198    ) -> Entity<Self> {
 199        cx.try_global::<ZetaGlobal>()
 200            .map(|global| global.0.clone())
 201            .unwrap_or_else(|| {
 202                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 203                cx.set_global(ZetaGlobal(zeta.clone()));
 204                zeta
 205            })
 206    }
 207
 208    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 209        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 210
 211        Self {
 212            projects: HashMap::new(),
 213            client,
 214            user_store,
 215            options: DEFAULT_OPTIONS,
 216            llm_token: LlmApiToken::default(),
 217            _llm_token_subscription: cx.subscribe(
 218                &refresh_llm_token_listener,
 219                |this, _listener, _event, cx| {
 220                    let client = this.client.clone();
 221                    let llm_token = this.llm_token.clone();
 222                    cx.spawn(async move |_this, _cx| {
 223                        llm_token.refresh(&client).await?;
 224                        anyhow::Ok(())
 225                    })
 226                    .detach_and_log_err(cx);
 227                },
 228            ),
 229            update_required: false,
 230            observe_tx: None,
 231        }
 232    }
 233
 234    pub fn observe_requests(&mut self) -> mpsc::UnboundedReceiver<ObservedPredictionRequest> {
 235        let (tx, rx) = mpsc::unbounded();
 236        self.observe_tx = Some(tx);
 237        rx
 238    }
 239
 240    pub fn options(&self) -> &ZetaOptions {
 241        &self.options
 242    }
 243
 244    pub fn set_options(&mut self, options: ZetaOptions) {
 245        self.options = options;
 246    }
 247
 248    pub fn clear_history(&mut self) {
 249        for zeta_project in self.projects.values_mut() {
 250            zeta_project.events.clear();
 251        }
 252    }
 253
 254    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 255        self.user_store.read(cx).edit_prediction_usage()
 256    }
 257
 258    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 259        self.get_or_init_zeta_project(project, cx);
 260    }
 261
 262    pub fn register_buffer(
 263        &mut self,
 264        buffer: &Entity<Buffer>,
 265        project: &Entity<Project>,
 266        cx: &mut Context<Self>,
 267    ) {
 268        let zeta_project = self.get_or_init_zeta_project(project, cx);
 269        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 270    }
 271
 272    fn get_or_init_zeta_project(
 273        &mut self,
 274        project: &Entity<Project>,
 275        cx: &mut App,
 276    ) -> &mut ZetaProject {
 277        self.projects
 278            .entry(project.entity_id())
 279            .or_insert_with(|| ZetaProject {
 280                syntax_index: cx.new(|cx| {
 281                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 282                }),
 283                events: VecDeque::new(),
 284                registered_buffers: HashMap::new(),
 285                current_prediction: None,
 286            })
 287    }
 288
 289    fn register_buffer_impl<'a>(
 290        zeta_project: &'a mut ZetaProject,
 291        buffer: &Entity<Buffer>,
 292        project: &Entity<Project>,
 293        cx: &mut Context<Self>,
 294    ) -> &'a mut RegisteredBuffer {
 295        let buffer_id = buffer.entity_id();
 296        match zeta_project.registered_buffers.entry(buffer_id) {
 297            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 298            hash_map::Entry::Vacant(entry) => {
 299                let snapshot = buffer.read(cx).snapshot();
 300                let project_entity_id = project.entity_id();
 301                entry.insert(RegisteredBuffer {
 302                    snapshot,
 303                    _subscriptions: [
 304                        cx.subscribe(buffer, {
 305                            let project = project.downgrade();
 306                            move |this, buffer, event, cx| {
 307                                if let language::BufferEvent::Edited = event
 308                                    && let Some(project) = project.upgrade()
 309                                {
 310                                    this.report_changes_for_buffer(&buffer, &project, cx);
 311                                }
 312                            }
 313                        }),
 314                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 315                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 316                            else {
 317                                return;
 318                            };
 319                            zeta_project.registered_buffers.remove(&buffer_id);
 320                        }),
 321                    ],
 322                })
 323            }
 324        }
 325    }
 326
 327    fn report_changes_for_buffer(
 328        &mut self,
 329        buffer: &Entity<Buffer>,
 330        project: &Entity<Project>,
 331        cx: &mut Context<Self>,
 332    ) -> BufferSnapshot {
 333        let zeta_project = self.get_or_init_zeta_project(project, cx);
 334        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 335
 336        let new_snapshot = buffer.read(cx).snapshot();
 337        if new_snapshot.version != registered_buffer.snapshot.version {
 338            let old_snapshot =
 339                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 340            Self::push_event(
 341                zeta_project,
 342                Event::BufferChange {
 343                    old_snapshot,
 344                    new_snapshot: new_snapshot.clone(),
 345                    timestamp: Instant::now(),
 346                },
 347            );
 348        }
 349
 350        new_snapshot
 351    }
 352
 353    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 354        let events = &mut zeta_project.events;
 355
 356        if let Some(Event::BufferChange {
 357            new_snapshot: last_new_snapshot,
 358            timestamp: last_timestamp,
 359            ..
 360        }) = events.back_mut()
 361        {
 362            // Coalesce edits for the same buffer when they happen one after the other.
 363            let Event::BufferChange {
 364                old_snapshot,
 365                new_snapshot,
 366                timestamp,
 367            } = &event;
 368
 369            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 370                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 371                && old_snapshot.version == last_new_snapshot.version
 372            {
 373                *last_new_snapshot = new_snapshot.clone();
 374                *last_timestamp = *timestamp;
 375                return;
 376            }
 377        }
 378
 379        if events.len() >= MAX_EVENT_COUNT {
 380            // These are halved instead of popping to improve prompt caching.
 381            events.drain(..MAX_EVENT_COUNT / 2);
 382        }
 383
 384        events.push_back(event);
 385    }
 386
 387    fn current_prediction_for_buffer(
 388        &self,
 389        buffer: &Entity<Buffer>,
 390        project: &Entity<Project>,
 391        cx: &App,
 392    ) -> Option<BufferEditPrediction<'_>> {
 393        let project_state = self.projects.get(&project.entity_id())?;
 394
 395        let CurrentEditPrediction {
 396            requested_by_buffer_id,
 397            prediction,
 398        } = project_state.current_prediction.as_ref()?;
 399
 400        if prediction.targets_buffer(buffer.read(cx), cx) {
 401            Some(BufferEditPrediction::Local { prediction })
 402        } else if *requested_by_buffer_id == buffer.entity_id() {
 403            Some(BufferEditPrediction::Jump { prediction })
 404        } else {
 405            None
 406        }
 407    }
 408
 409    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 410        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 411            return;
 412        };
 413
 414        let Some(prediction) = project_state.current_prediction.take() else {
 415            return;
 416        };
 417        let request_id = prediction.prediction.id.into();
 418
 419        let client = self.client.clone();
 420        let llm_token = self.llm_token.clone();
 421        let app_version = AppVersion::global(cx);
 422        cx.spawn(async move |this, cx| {
 423            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 424                http_client::Url::parse(&predict_edits_url)?
 425            } else {
 426                client
 427                    .http_client()
 428                    .build_zed_llm_url("/predict_edits/accept", &[])?
 429            };
 430
 431            let response = cx
 432                .background_spawn(Self::send_api_request::<()>(
 433                    move |builder| {
 434                        let req = builder.uri(url.as_ref()).body(
 435                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 436                        );
 437                        Ok(req?)
 438                    },
 439                    client,
 440                    llm_token,
 441                    app_version,
 442                ))
 443                .await;
 444
 445            Self::handle_api_response(&this, response, cx)?;
 446            anyhow::Ok(())
 447        })
 448        .detach_and_log_err(cx);
 449    }
 450
 451    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 452        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 453            project_state.current_prediction.take();
 454        };
 455    }
 456
 457    pub fn refresh_prediction(
 458        &mut self,
 459        project: &Entity<Project>,
 460        buffer: &Entity<Buffer>,
 461        position: language::Anchor,
 462        cx: &mut Context<Self>,
 463    ) -> Task<Result<()>> {
 464        let request_task = self.request_prediction(project, buffer, position, cx);
 465        let buffer = buffer.clone();
 466        let project = project.clone();
 467
 468        cx.spawn(async move |this, cx| {
 469            if let Some(prediction) = request_task.await? {
 470                this.update(cx, |this, cx| {
 471                    let project_state = this
 472                        .projects
 473                        .get_mut(&project.entity_id())
 474                        .context("Project not found")?;
 475
 476                    let new_prediction = CurrentEditPrediction {
 477                        requested_by_buffer_id: buffer.entity_id(),
 478                        prediction: prediction,
 479                    };
 480
 481                    if project_state
 482                        .current_prediction
 483                        .as_ref()
 484                        .is_none_or(|old_prediction| {
 485                            new_prediction.should_replace_prediction(&old_prediction, cx)
 486                        })
 487                    {
 488                        project_state.current_prediction = Some(new_prediction);
 489                    }
 490                    anyhow::Ok(())
 491                })??;
 492            }
 493            Ok(())
 494        })
 495    }
 496
 497    fn request_prediction(
 498        &mut self,
 499        project: &Entity<Project>,
 500        buffer: &Entity<Buffer>,
 501        position: language::Anchor,
 502        cx: &mut Context<Self>,
 503    ) -> Task<Result<Option<EditPrediction>>> {
 504        let project_state = self.projects.get(&project.entity_id());
 505
 506        let index_state = project_state.map(|state| {
 507            state
 508                .syntax_index
 509                .read_with(cx, |index, _cx| index.state().clone())
 510        });
 511        let options = self.options.clone();
 512        let snapshot = buffer.read(cx).snapshot();
 513        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 514            return Task::ready(Err(anyhow!("No file path for excerpt")));
 515        };
 516        let client = self.client.clone();
 517        let llm_token = self.llm_token.clone();
 518        let app_version = AppVersion::global(cx);
 519        let worktree_snapshots = project
 520            .read(cx)
 521            .worktrees(cx)
 522            .map(|worktree| worktree.read(cx).snapshot())
 523            .collect::<Vec<_>>();
 524        let observe_tx = self.observe_tx.clone();
 525
 526        let events = project_state
 527            .map(|state| {
 528                state
 529                    .events
 530                    .iter()
 531                    .filter_map(|event| match event {
 532                        Event::BufferChange {
 533                            old_snapshot,
 534                            new_snapshot,
 535                            ..
 536                        } => {
 537                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 538
 539                            let old_path = old_snapshot.file().and_then(|f| {
 540                                let old_path = f.full_path(cx);
 541                                if Some(&old_path) != path.as_ref() {
 542                                    Some(old_path)
 543                                } else {
 544                                    None
 545                                }
 546                            });
 547
 548                            // TODO [zeta2] move to bg?
 549                            let diff =
 550                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 551
 552                            if path == old_path && diff.is_empty() {
 553                                None
 554                            } else {
 555                                Some(predict_edits_v3::Event::BufferChange {
 556                                    old_path,
 557                                    path,
 558                                    diff,
 559                                    //todo: Actually detect if this edit was predicted or not
 560                                    predicted: false,
 561                                })
 562                            }
 563                        }
 564                    })
 565                    .collect::<Vec<_>>()
 566            })
 567            .unwrap_or_default();
 568
 569        let diagnostics = snapshot.diagnostic_sets().clone();
 570
 571        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 572            let mut path = f.worktree.read(cx).absolutize(&f.path);
 573            if path.pop() { Some(path) } else { None }
 574        });
 575
 576        // TODO data collection
 577        let can_collect_data = cx.is_staff();
 578
 579        let request_task = cx.background_spawn({
 580            let snapshot = snapshot.clone();
 581            let buffer = buffer.clone();
 582            async move {
 583                let index_state = if let Some(index_state) = index_state {
 584                    Some(index_state.lock_owned().await)
 585                } else {
 586                    None
 587                };
 588
 589                let cursor_offset = position.to_offset(&snapshot);
 590                let cursor_point = cursor_offset.to_point(&snapshot);
 591
 592                let before_retrieval = chrono::Utc::now();
 593
 594                let Some(context) = EditPredictionContext::gather_context(
 595                    cursor_point,
 596                    &snapshot,
 597                    parent_abs_path.as_deref(),
 598                    &options.context,
 599                    index_state.as_deref(),
 600                ) else {
 601                    return Ok((None, None));
 602                };
 603
 604                let retrieval_time = chrono::Utc::now() - before_retrieval;
 605
 606                let (diagnostic_groups, diagnostic_groups_truncated) =
 607                    Self::gather_nearby_diagnostics(
 608                        cursor_offset,
 609                        &diagnostics,
 610                        &snapshot,
 611                        options.max_diagnostic_bytes,
 612                    );
 613
 614                let request = make_cloud_request(
 615                    excerpt_path,
 616                    context,
 617                    events,
 618                    can_collect_data,
 619                    diagnostic_groups,
 620                    diagnostic_groups_truncated,
 621                    None,
 622                    observe_tx.is_some(),
 623                    &worktree_snapshots,
 624                    index_state.as_deref(),
 625                    Some(options.max_prompt_bytes),
 626                    options.prompt_format,
 627                );
 628
 629                let observe_response_tx = if let Some(observe_tx) = &observe_tx {
 630                    let (response_tx, response_rx) = oneshot::channel();
 631
 632                    let local_prompt = PlannedPrompt::populate(&request)
 633                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 634                        .map_err(|err| err.to_string());
 635
 636                    observe_tx
 637                        .unbounded_send(ObservedPredictionRequest {
 638                            full_request: request.clone(),
 639                            retrieval_time,
 640                            buffer: buffer.downgrade(),
 641                            local_prompt,
 642                            position,
 643                            response: async move {
 644                                let resp = response_rx
 645                                    .await
 646                                    .map_err(|_: oneshot::Canceled| "Canceled".to_string());
 647                                Ok(resp??)
 648                            }
 649                            .boxed()
 650                            .shared(),
 651                        })
 652                        .ok();
 653                    Some(response_tx)
 654                } else {
 655                    None
 656                };
 657
 658                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 659                    if let Some(observe_response_tx) = observe_response_tx {
 660                        observe_response_tx
 661                            .send(Err("Request skipped".to_string()))
 662                            .ok();
 663                    }
 664                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 665                }
 666
 667                let response =
 668                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 669
 670                if let Some(debug_response_tx) = observe_response_tx {
 671                    debug_response_tx
 672                        .send(
 673                            response
 674                                .as_ref()
 675                                .map_err(|err| err.to_string())
 676                                .map(|response| response.0.clone()),
 677                        )
 678                        .ok();
 679                }
 680
 681                response.map(|(res, usage)| (Some(res), usage))
 682            }
 683        });
 684
 685        let buffer = buffer.clone();
 686
 687        cx.spawn({
 688            let project = project.clone();
 689            async move |this, cx| {
 690                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 691                else {
 692                    return Ok(None);
 693                };
 694
 695                // TODO telemetry: duration, etc
 696                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 697            }
 698        })
 699    }
 700
 701    async fn send_prediction_request(
 702        client: Arc<Client>,
 703        llm_token: LlmApiToken,
 704        app_version: SemanticVersion,
 705        request: predict_edits_v3::PredictEditsRequest,
 706    ) -> Result<(
 707        predict_edits_v3::PredictEditsResponse,
 708        Option<EditPredictionUsage>,
 709    )> {
 710        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 711            http_client::Url::parse(&predict_edits_url)?
 712        } else {
 713            client
 714                .http_client()
 715                .build_zed_llm_url("/predict_edits/v3", &[])?
 716        };
 717
 718        Self::send_api_request(
 719            |builder| {
 720                let req = builder
 721                    .uri(url.as_ref())
 722                    .body(serde_json::to_string(&request)?.into());
 723                Ok(req?)
 724            },
 725            client,
 726            llm_token,
 727            app_version,
 728        )
 729        .await
 730    }
 731
 732    fn handle_api_response<T>(
 733        this: &WeakEntity<Self>,
 734        response: Result<(T, Option<EditPredictionUsage>)>,
 735        cx: &mut gpui::AsyncApp,
 736    ) -> Result<T> {
 737        match response {
 738            Ok((data, usage)) => {
 739                if let Some(usage) = usage {
 740                    this.update(cx, |this, cx| {
 741                        this.user_store.update(cx, |user_store, cx| {
 742                            user_store.update_edit_prediction_usage(usage, cx);
 743                        });
 744                    })
 745                    .ok();
 746                }
 747                Ok(data)
 748            }
 749            Err(err) => {
 750                if err.is::<ZedUpdateRequiredError>() {
 751                    cx.update(|cx| {
 752                        this.update(cx, |this, _cx| {
 753                            this.update_required = true;
 754                        })
 755                        .ok();
 756
 757                        let error_message: SharedString = err.to_string().into();
 758                        show_app_notification(
 759                            NotificationId::unique::<ZedUpdateRequiredError>(),
 760                            cx,
 761                            move |cx| {
 762                                cx.new(|cx| {
 763                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 764                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 765                                })
 766                            },
 767                        );
 768                    })
 769                    .ok();
 770                }
 771                Err(err)
 772            }
 773        }
 774    }
 775
 776    async fn send_api_request<Res>(
 777        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 778        client: Arc<Client>,
 779        llm_token: LlmApiToken,
 780        app_version: SemanticVersion,
 781    ) -> Result<(Res, Option<EditPredictionUsage>)>
 782    where
 783        Res: DeserializeOwned,
 784    {
 785        let http_client = client.http_client();
 786        let mut token = llm_token.acquire(&client).await?;
 787        let mut did_retry = false;
 788
 789        loop {
 790            let request_builder = http_client::Request::builder().method(Method::POST);
 791
 792            let request = build(
 793                request_builder
 794                    .header("Content-Type", "application/json")
 795                    .header("Authorization", format!("Bearer {}", token))
 796                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 797            )?;
 798
 799            let mut response = http_client.send(request).await?;
 800
 801            if let Some(minimum_required_version) = response
 802                .headers()
 803                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 804                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 805            {
 806                anyhow::ensure!(
 807                    app_version >= minimum_required_version,
 808                    ZedUpdateRequiredError {
 809                        minimum_version: minimum_required_version
 810                    }
 811                );
 812            }
 813
 814            if response.status().is_success() {
 815                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 816
 817                let mut body = Vec::new();
 818                response.body_mut().read_to_end(&mut body).await?;
 819                return Ok((serde_json::from_slice(&body)?, usage));
 820            } else if !did_retry
 821                && response
 822                    .headers()
 823                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 824                    .is_some()
 825            {
 826                did_retry = true;
 827                token = llm_token.refresh(&client).await?;
 828            } else {
 829                let mut body = String::new();
 830                response.body_mut().read_to_string(&mut body).await?;
 831                anyhow::bail!(
 832                    "Request failed with status: {:?}\nBody: {}",
 833                    response.status(),
 834                    body
 835                );
 836            }
 837        }
 838    }
 839
 840    fn gather_nearby_diagnostics(
 841        cursor_offset: usize,
 842        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 843        snapshot: &BufferSnapshot,
 844        max_diagnostics_bytes: usize,
 845    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 846        // TODO: Could make this more efficient
 847        let mut diagnostic_groups = Vec::new();
 848        for (language_server_id, diagnostics) in diagnostic_sets {
 849            let mut groups = Vec::new();
 850            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 851            diagnostic_groups.extend(
 852                groups
 853                    .into_iter()
 854                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 855            );
 856        }
 857
 858        // sort by proximity to cursor
 859        diagnostic_groups.sort_by_key(|group| {
 860            let range = &group.entries[group.primary_ix].range;
 861            if range.start >= cursor_offset {
 862                range.start - cursor_offset
 863            } else if cursor_offset >= range.end {
 864                cursor_offset - range.end
 865            } else {
 866                (cursor_offset - range.start).min(range.end - cursor_offset)
 867            }
 868        });
 869
 870        let mut results = Vec::new();
 871        let mut diagnostic_groups_truncated = false;
 872        let mut diagnostics_byte_count = 0;
 873        for group in diagnostic_groups {
 874            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 875            diagnostics_byte_count += raw_value.get().len();
 876            if diagnostics_byte_count > max_diagnostics_bytes {
 877                diagnostic_groups_truncated = true;
 878                break;
 879            }
 880            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 881        }
 882
 883        (results, diagnostic_groups_truncated)
 884    }
 885
 886    // TODO: Dedupe with similar code in request_prediction?
 887    pub fn cloud_request_for_zeta_cli(
 888        &mut self,
 889        project: &Entity<Project>,
 890        buffer: &Entity<Buffer>,
 891        position: language::Anchor,
 892        cx: &mut Context<Self>,
 893    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 894        let project_state = self.projects.get(&project.entity_id());
 895
 896        let index_state = project_state.map(|state| {
 897            state
 898                .syntax_index
 899                .read_with(cx, |index, _cx| index.state().clone())
 900        });
 901        let options = self.options.clone();
 902        let snapshot = buffer.read(cx).snapshot();
 903        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 904            return Task::ready(Err(anyhow!("No file path for excerpt")));
 905        };
 906        let worktree_snapshots = project
 907            .read(cx)
 908            .worktrees(cx)
 909            .map(|worktree| worktree.read(cx).snapshot())
 910            .collect::<Vec<_>>();
 911
 912        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 913            let mut path = f.worktree.read(cx).absolutize(&f.path);
 914            if path.pop() { Some(path) } else { None }
 915        });
 916
 917        cx.background_spawn(async move {
 918            let index_state = if let Some(index_state) = index_state {
 919                Some(index_state.lock_owned().await)
 920            } else {
 921                None
 922            };
 923
 924            let cursor_point = position.to_point(&snapshot);
 925
 926            let debug_info = true;
 927            EditPredictionContext::gather_context(
 928                cursor_point,
 929                &snapshot,
 930                parent_abs_path.as_deref(),
 931                &options.context,
 932                index_state.as_deref(),
 933            )
 934            .context("Failed to select excerpt")
 935            .map(|context| {
 936                make_cloud_request(
 937                    excerpt_path.into(),
 938                    context,
 939                    // TODO pass everything
 940                    Vec::new(),
 941                    false,
 942                    Vec::new(),
 943                    false,
 944                    None,
 945                    debug_info,
 946                    &worktree_snapshots,
 947                    index_state.as_deref(),
 948                    Some(options.max_prompt_bytes),
 949                    options.prompt_format,
 950                )
 951            })
 952        })
 953    }
 954
 955    pub fn wait_for_initial_indexing(
 956        &mut self,
 957        project: &Entity<Project>,
 958        cx: &mut App,
 959    ) -> Task<Result<()>> {
 960        let zeta_project = self.get_or_init_zeta_project(project, cx);
 961        zeta_project
 962            .syntax_index
 963            .read(cx)
 964            .wait_for_initial_file_indexing(cx)
 965    }
 966}
 967
 968#[derive(Error, Debug)]
 969#[error(
 970    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 971)]
 972pub struct ZedUpdateRequiredError {
 973    minimum_version: SemanticVersion,
 974}
 975
 976fn make_cloud_request(
 977    excerpt_path: Arc<Path>,
 978    context: EditPredictionContext,
 979    events: Vec<predict_edits_v3::Event>,
 980    can_collect_data: bool,
 981    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 982    diagnostic_groups_truncated: bool,
 983    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 984    debug_info: bool,
 985    worktrees: &Vec<worktree::Snapshot>,
 986    index_state: Option<&SyntaxIndexState>,
 987    prompt_max_bytes: Option<usize>,
 988    prompt_format: PromptFormat,
 989) -> predict_edits_v3::PredictEditsRequest {
 990    let mut signatures = Vec::new();
 991    let mut declaration_to_signature_index = HashMap::default();
 992    let mut referenced_declarations = Vec::new();
 993
 994    for snippet in context.declarations {
 995        let project_entry_id = snippet.declaration.project_entry_id();
 996        let Some(path) = worktrees.iter().find_map(|worktree| {
 997            worktree.entry_for_id(project_entry_id).map(|entry| {
 998                let mut full_path = RelPathBuf::new();
 999                full_path.push(worktree.root_name());
1000                full_path.push(&entry.path);
1001                full_path
1002            })
1003        }) else {
1004            continue;
1005        };
1006
1007        let parent_index = index_state.and_then(|index_state| {
1008            snippet.declaration.parent().and_then(|parent| {
1009                add_signature(
1010                    parent,
1011                    &mut declaration_to_signature_index,
1012                    &mut signatures,
1013                    index_state,
1014                )
1015            })
1016        });
1017
1018        let (text, text_is_truncated) = snippet.declaration.item_text();
1019        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1020            path: path.as_std_path().into(),
1021            text: text.into(),
1022            range: snippet.declaration.item_line_range(),
1023            text_is_truncated,
1024            signature_range: snippet.declaration.signature_range_in_item_text(),
1025            parent_index,
1026            signature_score: snippet.score(DeclarationStyle::Signature),
1027            declaration_score: snippet.score(DeclarationStyle::Declaration),
1028            score_components: snippet.components,
1029        });
1030    }
1031
1032    let excerpt_parent = index_state.and_then(|index_state| {
1033        context
1034            .excerpt
1035            .parent_declarations
1036            .last()
1037            .and_then(|(parent, _)| {
1038                add_signature(
1039                    *parent,
1040                    &mut declaration_to_signature_index,
1041                    &mut signatures,
1042                    index_state,
1043                )
1044            })
1045    });
1046
1047    predict_edits_v3::PredictEditsRequest {
1048        excerpt_path,
1049        excerpt: context.excerpt_text.body,
1050        excerpt_line_range: context.excerpt.line_range,
1051        excerpt_range: context.excerpt.range,
1052        cursor_point: predict_edits_v3::Point {
1053            line: predict_edits_v3::Line(context.cursor_point.row),
1054            column: context.cursor_point.column,
1055        },
1056        referenced_declarations,
1057        signatures,
1058        excerpt_parent,
1059        events,
1060        can_collect_data,
1061        diagnostic_groups,
1062        diagnostic_groups_truncated,
1063        git_info,
1064        debug_info,
1065        prompt_max_bytes,
1066        prompt_format,
1067    }
1068}
1069
1070fn add_signature(
1071    declaration_id: DeclarationId,
1072    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1073    signatures: &mut Vec<Signature>,
1074    index: &SyntaxIndexState,
1075) -> Option<usize> {
1076    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1077        return Some(*signature_index);
1078    }
1079    let Some(parent_declaration) = index.declaration(declaration_id) else {
1080        log::error!("bug: missing parent declaration");
1081        return None;
1082    };
1083    let parent_index = parent_declaration.parent().and_then(|parent| {
1084        add_signature(parent, declaration_to_signature_index, signatures, index)
1085    });
1086    let (text, text_is_truncated) = parent_declaration.signature_text();
1087    let signature_index = signatures.len();
1088    signatures.push(Signature {
1089        text: text.into(),
1090        text_is_truncated,
1091        parent_index,
1092        range: parent_declaration.signature_line_range(),
1093    });
1094    declaration_to_signature_index.insert(declaration_id, signature_index);
1095    Some(signature_index)
1096}
1097
// Tests for Zeta's prediction request/response flow. `init_test` installs a fake HTTP client
// that forwards every `/predict_edits/v3` request to a channel, so each test can inspect the
// outgoing request and answer it with a hand-written response.
#[cfg(test)]
mod tests {
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    // Checks what `current_prediction_for_buffer` reports, in three steps:
    // 1. `Local` when the predicted edit targets the requesting buffer.
    // 2. `Jump` (to 2.txt) when it targets another file.
    // 3. `Local` for 2.txt once that file's buffer is opened.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // The request contents are not checked here; only the resulting prediction state is.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        // The 2.txt edit reuses `snapshot1`'s row count; both files have three lines.
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // Opening the target file turns the jump into a local prediction for that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    // Checks the request's excerpt path and cursor point. Also checks that a response replacing
    // the whole file is reduced to a single insertion (" are you?") at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that an edit made to a registered buffer appears in the request as a
    // `BufferChange` event carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is just past "Hello!\n", i.e. the start of the empty second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    // Checks that a diagnostic published by a language server is serialized into the
    // request's `diagnostic_groups`. Only the outgoing request is inspected; no response is
    // ever sent.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // LSP position (1, 1)..(1, 5) on the "Bye" line. The end column runs past that
        // line's three characters.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        // Diagnostics are published before the buffer is opened.
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Range offsets: line 1 starts at byte 7, so column 1 is offset 8. The asserted end,
        // 10, is the end of "Bye" (and of the buffer), so the out-of-range end column 5 ends
        // up clipped.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    // Sets up global test settings and a `Zeta` whose client talks to a fake HTTP server.
    //
    // The fake server handles paths as follows:
    // - `/client/llm_tokens` answers with a fixed token.
    // - `/predict_edits/v3` decodes the request body and sends it on the returned receiver,
    //   paired with a oneshot sender. The fake server then waits for the test to send the
    //   response through that sender.
    // - Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                // If the test drops the sender without responding, `?` returns
                                // the cancellation as this handler's error.
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}