zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
   7    ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::{DEFAULT_MAX_PROMPT_BYTES, PlannedPrompt};
  10use edit_prediction_context::{
  11    DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
  12    EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState,
  13};
  14use futures::AsyncReadExt as _;
  15use futures::channel::{mpsc, oneshot};
  16use gpui::http_client::{AsyncBody, Method};
  17use gpui::{
  18    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  19    http_client, prelude::*,
  20};
  21use language::BufferSnapshot;
  22use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  23use language_model::{LlmApiToken, RefreshLlmTokenListener};
  24use project::Project;
  25use release_channel::AppVersion;
  26use serde::de::DeserializeOwned;
  27use std::collections::{HashMap, VecDeque, hash_map};
  28use std::path::Path;
  29use std::str::FromStr as _;
  30use std::sync::Arc;
  31use std::time::{Duration, Instant};
  32use thiserror::Error;
  33use util::rel_path::RelPathBuf;
  34use util::some_or_debug_panic;
  35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  36
  37mod prediction;
  38mod provider;
  39
  40use crate::prediction::EditPrediction;
  41pub use provider::ZetaEditPredictionProvider;
  42
  43const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  44
  45/// Maximum number of events to track.
  46const MAX_EVENT_COUNT: usize = 16;
  47
  48pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions {
  49    use_imports: true,
  50    use_references: false,
  51    excerpt: EditPredictionExcerptOptions {
  52        max_bytes: 512,
  53        min_bytes: 128,
  54        target_before_cursor_over_total_bytes: 0.5,
  55    },
  56    score: EditPredictionScoreOptions {
  57        omit_excerpt_overlaps: true,
  58    },
  59};
  60
  61pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  62    context: DEFAULT_CONTEXT_OPTIONS,
  63    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  64    max_diagnostic_bytes: 2048,
  65    prompt_format: PromptFormat::DEFAULT,
  66    file_indexing_parallelism: 1,
  67};
  68
/// App-global slot holding the single [`Zeta`] entity; installed by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  73
/// Edit prediction engine: tracks per-project edit history and requests edit predictions
/// from the LLM service.
pub struct Zeta {
    /// Used for HTTP requests and for acquiring/refreshing `llm_token`.
    client: Arc<Client>,
    /// Receives edit prediction usage reported in API responses.
    user_store: Entity<UserStore>,
    /// Bearer token for LLM service requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    /// Never read; held so the token-refresh subscription created in `new` is not dropped.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once the server reports that a newer Zed version is required.
    update_required: bool,
    /// When set (via `debug_info`), debug info for each prediction request is sent here.
    debug_tx: Option<mpsc::UnboundedSender<PredictionDebugInfo>>,
}
  84
/// Tunable parameters for gathering context and building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options for gathering context around the cursor.
    pub context: EditPredictionContextOptions,
    /// Prompt size limit passed to the cloud request builder.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostics included in a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format included in the request.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project is first registered.
    pub file_indexing_parallelism: usize,
}
  93
/// Debug information sent for each prediction request while a debug channel is installed
/// (see `Zeta::debug_info`).
pub struct PredictionDebugInfo {
    /// Context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering `context`.
    pub retrieval_time: TimeDelta,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested for.
    pub position: language::Anchor,
    /// Prompt rendered locally from the request, or the rendering error message.
    pub local_prompt: Result<String, String>,
    /// Resolves with the server's debug info, or an error message when the request was
    /// skipped, failed, or returned no debug info.
    pub response_rx: oneshot::Receiver<Result<RequestDebugInfo, String>>,
}

/// Debug info returned by the server alongside a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
 104
/// Per-project prediction state.
struct ZetaProject {
    /// Syntax index used when gathering context for requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; never longer than `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently held for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
 111
/// A prediction together with the buffer whose request produced it.
#[derive(Debug, Clone)]
struct CurrentEditPrediction {
    /// Buffer the prediction was requested from; may differ from the buffer it targets.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 117
 118impl CurrentEditPrediction {
 119    fn should_replace_prediction(&self, old_prediction: &Self, cx: &App) -> bool {
 120        let Some(new_edits) = self
 121            .prediction
 122            .interpolate(&self.prediction.buffer.read(cx))
 123        else {
 124            return false;
 125        };
 126
 127        if self.prediction.buffer != old_prediction.prediction.buffer {
 128            return true;
 129        }
 130
 131        let Some(old_edits) = old_prediction
 132            .prediction
 133            .interpolate(&old_prediction.prediction.buffer.read(cx))
 134        else {
 135            return true;
 136        };
 137
 138        // This reduces the occurrence of UI thrash from replacing edits
 139        //
 140        // TODO: This is fairly arbitrary - should have a more general heuristic that handles multiple edits.
 141        if self.requested_by_buffer_id == self.prediction.buffer.entity_id()
 142            && self.requested_by_buffer_id == old_prediction.prediction.buffer.entity_id()
 143            && old_edits.len() == 1
 144            && new_edits.len() == 1
 145        {
 146            let (old_range, old_text) = &old_edits[0];
 147            let (new_range, new_text) = &new_edits[0];
 148            new_range == old_range && new_text.starts_with(old_text)
 149        } else {
 150            true
 151        }
 152    }
 153}
 154
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction was requested from this buffer but targets a different one.
    Jump { prediction: &'a EditPrediction },
}
 161
/// A buffer whose edits are being tracked for a project.
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; becomes the `old_snapshot` of the next event.
    snapshot: BufferSnapshot,
    /// Edit and release subscriptions; never read, held so they are not dropped.
    _subscriptions: [gpui::Subscription; 2],
}
 166
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Consecutive changes to the
    /// same buffer within `BUFFER_CHANGE_GROUPING_INTERVAL` are merged into one event, whose
    /// `new_snapshot` and `timestamp` reflect the latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 175
 176impl Zeta {
 177    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 178        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 179    }
 180
 181    pub fn global(
 182        client: &Arc<Client>,
 183        user_store: &Entity<UserStore>,
 184        cx: &mut App,
 185    ) -> Entity<Self> {
 186        cx.try_global::<ZetaGlobal>()
 187            .map(|global| global.0.clone())
 188            .unwrap_or_else(|| {
 189                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 190                cx.set_global(ZetaGlobal(zeta.clone()));
 191                zeta
 192            })
 193    }
 194
 195    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 196        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 197
 198        Self {
 199            projects: HashMap::new(),
 200            client,
 201            user_store,
 202            options: DEFAULT_OPTIONS,
 203            llm_token: LlmApiToken::default(),
 204            _llm_token_subscription: cx.subscribe(
 205                &refresh_llm_token_listener,
 206                |this, _listener, _event, cx| {
 207                    let client = this.client.clone();
 208                    let llm_token = this.llm_token.clone();
 209                    cx.spawn(async move |_this, _cx| {
 210                        llm_token.refresh(&client).await?;
 211                        anyhow::Ok(())
 212                    })
 213                    .detach_and_log_err(cx);
 214                },
 215            ),
 216            update_required: false,
 217            debug_tx: None,
 218        }
 219    }
 220
 221    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<PredictionDebugInfo> {
 222        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 223        self.debug_tx = Some(debug_watch_tx);
 224        debug_watch_rx
 225    }
 226
 227    pub fn options(&self) -> &ZetaOptions {
 228        &self.options
 229    }
 230
 231    pub fn set_options(&mut self, options: ZetaOptions) {
 232        self.options = options;
 233    }
 234
 235    pub fn clear_history(&mut self) {
 236        for zeta_project in self.projects.values_mut() {
 237            zeta_project.events.clear();
 238        }
 239    }
 240
 241    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 242        self.user_store.read(cx).edit_prediction_usage()
 243    }
 244
 245    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 246        self.get_or_init_zeta_project(project, cx);
 247    }
 248
 249    pub fn register_buffer(
 250        &mut self,
 251        buffer: &Entity<Buffer>,
 252        project: &Entity<Project>,
 253        cx: &mut Context<Self>,
 254    ) {
 255        let zeta_project = self.get_or_init_zeta_project(project, cx);
 256        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 257    }
 258
 259    fn get_or_init_zeta_project(
 260        &mut self,
 261        project: &Entity<Project>,
 262        cx: &mut App,
 263    ) -> &mut ZetaProject {
 264        self.projects
 265            .entry(project.entity_id())
 266            .or_insert_with(|| ZetaProject {
 267                syntax_index: cx.new(|cx| {
 268                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 269                }),
 270                events: VecDeque::new(),
 271                registered_buffers: HashMap::new(),
 272                current_prediction: None,
 273            })
 274    }
 275
 276    fn register_buffer_impl<'a>(
 277        zeta_project: &'a mut ZetaProject,
 278        buffer: &Entity<Buffer>,
 279        project: &Entity<Project>,
 280        cx: &mut Context<Self>,
 281    ) -> &'a mut RegisteredBuffer {
 282        let buffer_id = buffer.entity_id();
 283        match zeta_project.registered_buffers.entry(buffer_id) {
 284            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 285            hash_map::Entry::Vacant(entry) => {
 286                let snapshot = buffer.read(cx).snapshot();
 287                let project_entity_id = project.entity_id();
 288                entry.insert(RegisteredBuffer {
 289                    snapshot,
 290                    _subscriptions: [
 291                        cx.subscribe(buffer, {
 292                            let project = project.downgrade();
 293                            move |this, buffer, event, cx| {
 294                                if let language::BufferEvent::Edited = event
 295                                    && let Some(project) = project.upgrade()
 296                                {
 297                                    this.report_changes_for_buffer(&buffer, &project, cx);
 298                                }
 299                            }
 300                        }),
 301                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 302                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 303                            else {
 304                                return;
 305                            };
 306                            zeta_project.registered_buffers.remove(&buffer_id);
 307                        }),
 308                    ],
 309                })
 310            }
 311        }
 312    }
 313
 314    fn report_changes_for_buffer(
 315        &mut self,
 316        buffer: &Entity<Buffer>,
 317        project: &Entity<Project>,
 318        cx: &mut Context<Self>,
 319    ) -> BufferSnapshot {
 320        let zeta_project = self.get_or_init_zeta_project(project, cx);
 321        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 322
 323        let new_snapshot = buffer.read(cx).snapshot();
 324        if new_snapshot.version != registered_buffer.snapshot.version {
 325            let old_snapshot =
 326                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 327            Self::push_event(
 328                zeta_project,
 329                Event::BufferChange {
 330                    old_snapshot,
 331                    new_snapshot: new_snapshot.clone(),
 332                    timestamp: Instant::now(),
 333                },
 334            );
 335        }
 336
 337        new_snapshot
 338    }
 339
 340    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 341        let events = &mut zeta_project.events;
 342
 343        if let Some(Event::BufferChange {
 344            new_snapshot: last_new_snapshot,
 345            timestamp: last_timestamp,
 346            ..
 347        }) = events.back_mut()
 348        {
 349            // Coalesce edits for the same buffer when they happen one after the other.
 350            let Event::BufferChange {
 351                old_snapshot,
 352                new_snapshot,
 353                timestamp,
 354            } = &event;
 355
 356            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 357                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 358                && old_snapshot.version == last_new_snapshot.version
 359            {
 360                *last_new_snapshot = new_snapshot.clone();
 361                *last_timestamp = *timestamp;
 362                return;
 363            }
 364        }
 365
 366        if events.len() >= MAX_EVENT_COUNT {
 367            // These are halved instead of popping to improve prompt caching.
 368            events.drain(..MAX_EVENT_COUNT / 2);
 369        }
 370
 371        events.push_back(event);
 372    }
 373
 374    fn current_prediction_for_buffer(
 375        &self,
 376        buffer: &Entity<Buffer>,
 377        project: &Entity<Project>,
 378        cx: &App,
 379    ) -> Option<BufferEditPrediction<'_>> {
 380        let project_state = self.projects.get(&project.entity_id())?;
 381
 382        let CurrentEditPrediction {
 383            requested_by_buffer_id,
 384            prediction,
 385        } = project_state.current_prediction.as_ref()?;
 386
 387        if prediction.targets_buffer(buffer.read(cx), cx) {
 388            Some(BufferEditPrediction::Local { prediction })
 389        } else if *requested_by_buffer_id == buffer.entity_id() {
 390            Some(BufferEditPrediction::Jump { prediction })
 391        } else {
 392            None
 393        }
 394    }
 395
 396    fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
 397        let Some(project_state) = self.projects.get_mut(&project.entity_id()) else {
 398            return;
 399        };
 400
 401        let Some(prediction) = project_state.current_prediction.take() else {
 402            return;
 403        };
 404        let request_id = prediction.prediction.id.into();
 405
 406        let client = self.client.clone();
 407        let llm_token = self.llm_token.clone();
 408        let app_version = AppVersion::global(cx);
 409        cx.spawn(async move |this, cx| {
 410            let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
 411                http_client::Url::parse(&predict_edits_url)?
 412            } else {
 413                client
 414                    .http_client()
 415                    .build_zed_llm_url("/predict_edits/accept", &[])?
 416            };
 417
 418            let response = cx
 419                .background_spawn(Self::send_api_request::<()>(
 420                    move |builder| {
 421                        let req = builder.uri(url.as_ref()).body(
 422                            serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(),
 423                        );
 424                        Ok(req?)
 425                    },
 426                    client,
 427                    llm_token,
 428                    app_version,
 429                ))
 430                .await;
 431
 432            Self::handle_api_response(&this, response, cx)?;
 433            anyhow::Ok(())
 434        })
 435        .detach_and_log_err(cx);
 436    }
 437
 438    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 439        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 440            project_state.current_prediction.take();
 441        };
 442    }
 443
 444    pub fn refresh_prediction(
 445        &mut self,
 446        project: &Entity<Project>,
 447        buffer: &Entity<Buffer>,
 448        position: language::Anchor,
 449        cx: &mut Context<Self>,
 450    ) -> Task<Result<()>> {
 451        let request_task = self.request_prediction(project, buffer, position, cx);
 452        let buffer = buffer.clone();
 453        let project = project.clone();
 454
 455        cx.spawn(async move |this, cx| {
 456            if let Some(prediction) = request_task.await? {
 457                this.update(cx, |this, cx| {
 458                    let project_state = this
 459                        .projects
 460                        .get_mut(&project.entity_id())
 461                        .context("Project not found")?;
 462
 463                    let new_prediction = CurrentEditPrediction {
 464                        requested_by_buffer_id: buffer.entity_id(),
 465                        prediction: prediction,
 466                    };
 467
 468                    if project_state
 469                        .current_prediction
 470                        .as_ref()
 471                        .is_none_or(|old_prediction| {
 472                            new_prediction.should_replace_prediction(&old_prediction, cx)
 473                        })
 474                    {
 475                        project_state.current_prediction = Some(new_prediction);
 476                    }
 477                    anyhow::Ok(())
 478                })??;
 479            }
 480            Ok(())
 481        })
 482    }
 483
 484    fn request_prediction(
 485        &mut self,
 486        project: &Entity<Project>,
 487        buffer: &Entity<Buffer>,
 488        position: language::Anchor,
 489        cx: &mut Context<Self>,
 490    ) -> Task<Result<Option<EditPrediction>>> {
 491        let project_state = self.projects.get(&project.entity_id());
 492
 493        let index_state = project_state.map(|state| {
 494            state
 495                .syntax_index
 496                .read_with(cx, |index, _cx| index.state().clone())
 497        });
 498        let options = self.options.clone();
 499        let snapshot = buffer.read(cx).snapshot();
 500        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 501            return Task::ready(Err(anyhow!("No file path for excerpt")));
 502        };
 503        let client = self.client.clone();
 504        let llm_token = self.llm_token.clone();
 505        let app_version = AppVersion::global(cx);
 506        let worktree_snapshots = project
 507            .read(cx)
 508            .worktrees(cx)
 509            .map(|worktree| worktree.read(cx).snapshot())
 510            .collect::<Vec<_>>();
 511        let debug_tx = self.debug_tx.clone();
 512
 513        let events = project_state
 514            .map(|state| {
 515                state
 516                    .events
 517                    .iter()
 518                    .filter_map(|event| match event {
 519                        Event::BufferChange {
 520                            old_snapshot,
 521                            new_snapshot,
 522                            ..
 523                        } => {
 524                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 525
 526                            let old_path = old_snapshot.file().and_then(|f| {
 527                                let old_path = f.full_path(cx);
 528                                if Some(&old_path) != path.as_ref() {
 529                                    Some(old_path)
 530                                } else {
 531                                    None
 532                                }
 533                            });
 534
 535                            // TODO [zeta2] move to bg?
 536                            let diff =
 537                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 538
 539                            if path == old_path && diff.is_empty() {
 540                                None
 541                            } else {
 542                                Some(predict_edits_v3::Event::BufferChange {
 543                                    old_path,
 544                                    path,
 545                                    diff,
 546                                    //todo: Actually detect if this edit was predicted or not
 547                                    predicted: false,
 548                                })
 549                            }
 550                        }
 551                    })
 552                    .collect::<Vec<_>>()
 553            })
 554            .unwrap_or_default();
 555
 556        let diagnostics = snapshot.diagnostic_sets().clone();
 557
 558        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 559            let mut path = f.worktree.read(cx).absolutize(&f.path);
 560            if path.pop() { Some(path) } else { None }
 561        });
 562
 563        let request_task = cx.background_spawn({
 564            let snapshot = snapshot.clone();
 565            let buffer = buffer.clone();
 566            async move {
 567                let index_state = if let Some(index_state) = index_state {
 568                    Some(index_state.lock_owned().await)
 569                } else {
 570                    None
 571                };
 572
 573                let cursor_offset = position.to_offset(&snapshot);
 574                let cursor_point = cursor_offset.to_point(&snapshot);
 575
 576                let before_retrieval = chrono::Utc::now();
 577
 578                let Some(context) = EditPredictionContext::gather_context(
 579                    cursor_point,
 580                    &snapshot,
 581                    parent_abs_path.as_deref(),
 582                    &options.context,
 583                    index_state.as_deref(),
 584                ) else {
 585                    return Ok((None, None));
 586                };
 587
 588                let retrieval_time = chrono::Utc::now() - before_retrieval;
 589
 590                let (diagnostic_groups, diagnostic_groups_truncated) =
 591                    Self::gather_nearby_diagnostics(
 592                        cursor_offset,
 593                        &diagnostics,
 594                        &snapshot,
 595                        options.max_diagnostic_bytes,
 596                    );
 597
 598                let debug_context = debug_tx.map(|tx| (tx, context.clone()));
 599
 600                let request = make_cloud_request(
 601                    excerpt_path,
 602                    context,
 603                    events,
 604                    // TODO data collection
 605                    false,
 606                    diagnostic_groups,
 607                    diagnostic_groups_truncated,
 608                    None,
 609                    debug_context.is_some(),
 610                    &worktree_snapshots,
 611                    index_state.as_deref(),
 612                    Some(options.max_prompt_bytes),
 613                    options.prompt_format,
 614                );
 615
 616                let debug_response_tx = if let Some((debug_tx, context)) = debug_context {
 617                    let (response_tx, response_rx) = oneshot::channel();
 618
 619                    let local_prompt = PlannedPrompt::populate(&request)
 620                        .and_then(|p| p.to_prompt_string().map(|p| p.0))
 621                        .map_err(|err| err.to_string());
 622
 623                    debug_tx
 624                        .unbounded_send(PredictionDebugInfo {
 625                            context,
 626                            retrieval_time,
 627                            buffer: buffer.downgrade(),
 628                            local_prompt,
 629                            position,
 630                            response_rx,
 631                        })
 632                        .ok();
 633                    Some(response_tx)
 634                } else {
 635                    None
 636                };
 637
 638                if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
 639                    if let Some(debug_response_tx) = debug_response_tx {
 640                        debug_response_tx
 641                            .send(Err("Request skipped".to_string()))
 642                            .ok();
 643                    }
 644                    anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
 645                }
 646
 647                let response =
 648                    Self::send_prediction_request(client, llm_token, app_version, request).await;
 649
 650                if let Some(debug_response_tx) = debug_response_tx {
 651                    debug_response_tx
 652                        .send(response.as_ref().map_err(|err| err.to_string()).and_then(
 653                            |response| match some_or_debug_panic(response.0.debug_info.clone()) {
 654                                Some(debug_info) => Ok(debug_info),
 655                                None => Err("Missing debug info".to_string()),
 656                            },
 657                        ))
 658                        .ok();
 659                }
 660
 661                response.map(|(res, usage)| (Some(res), usage))
 662            }
 663        });
 664
 665        let buffer = buffer.clone();
 666
 667        cx.spawn({
 668            let project = project.clone();
 669            async move |this, cx| {
 670                let Some(response) = Self::handle_api_response(&this, request_task.await, cx)?
 671                else {
 672                    return Ok(None);
 673                };
 674
 675                // TODO telemetry: duration, etc
 676                Ok(EditPrediction::from_response(response, &snapshot, &buffer, &project, cx).await)
 677            }
 678        })
 679    }
 680
 681    async fn send_prediction_request(
 682        client: Arc<Client>,
 683        llm_token: LlmApiToken,
 684        app_version: SemanticVersion,
 685        request: predict_edits_v3::PredictEditsRequest,
 686    ) -> Result<(
 687        predict_edits_v3::PredictEditsResponse,
 688        Option<EditPredictionUsage>,
 689    )> {
 690        let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 691            http_client::Url::parse(&predict_edits_url)?
 692        } else {
 693            client
 694                .http_client()
 695                .build_zed_llm_url("/predict_edits/v3", &[])?
 696        };
 697
 698        Self::send_api_request(
 699            |builder| {
 700                let req = builder
 701                    .uri(url.as_ref())
 702                    .body(serde_json::to_string(&request)?.into());
 703                Ok(req?)
 704            },
 705            client,
 706            llm_token,
 707            app_version,
 708        )
 709        .await
 710    }
 711
 712    fn handle_api_response<T>(
 713        this: &WeakEntity<Self>,
 714        response: Result<(T, Option<EditPredictionUsage>)>,
 715        cx: &mut gpui::AsyncApp,
 716    ) -> Result<T> {
 717        match response {
 718            Ok((data, usage)) => {
 719                if let Some(usage) = usage {
 720                    this.update(cx, |this, cx| {
 721                        this.user_store.update(cx, |user_store, cx| {
 722                            user_store.update_edit_prediction_usage(usage, cx);
 723                        });
 724                    })
 725                    .ok();
 726                }
 727                Ok(data)
 728            }
 729            Err(err) => {
 730                if err.is::<ZedUpdateRequiredError>() {
 731                    cx.update(|cx| {
 732                        this.update(cx, |this, _cx| {
 733                            this.update_required = true;
 734                        })
 735                        .ok();
 736
 737                        let error_message: SharedString = err.to_string().into();
 738                        show_app_notification(
 739                            NotificationId::unique::<ZedUpdateRequiredError>(),
 740                            cx,
 741                            move |cx| {
 742                                cx.new(|cx| {
 743                                    ErrorMessagePrompt::new(error_message.clone(), cx)
 744                                        .with_link_button("Update Zed", "https://zed.dev/releases")
 745                                })
 746                            },
 747                        );
 748                    })
 749                    .ok();
 750                }
 751                Err(err)
 752            }
 753        }
 754    }
 755
 756    async fn send_api_request<Res>(
 757        build: impl Fn(http_client::http::request::Builder) -> Result<http_client::Request<AsyncBody>>,
 758        client: Arc<Client>,
 759        llm_token: LlmApiToken,
 760        app_version: SemanticVersion,
 761    ) -> Result<(Res, Option<EditPredictionUsage>)>
 762    where
 763        Res: DeserializeOwned,
 764    {
 765        let http_client = client.http_client();
 766        let mut token = llm_token.acquire(&client).await?;
 767        let mut did_retry = false;
 768
 769        loop {
 770            let request_builder = http_client::Request::builder().method(Method::POST);
 771
 772            let request = build(
 773                request_builder
 774                    .header("Content-Type", "application/json")
 775                    .header("Authorization", format!("Bearer {}", token))
 776                    .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
 777            )?;
 778
 779            let mut response = http_client.send(request).await?;
 780
 781            if let Some(minimum_required_version) = response
 782                .headers()
 783                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 784                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 785            {
 786                anyhow::ensure!(
 787                    app_version >= minimum_required_version,
 788                    ZedUpdateRequiredError {
 789                        minimum_version: minimum_required_version
 790                    }
 791                );
 792            }
 793
 794            if response.status().is_success() {
 795                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 796
 797                let mut body = Vec::new();
 798                response.body_mut().read_to_end(&mut body).await?;
 799                return Ok((serde_json::from_slice(&body)?, usage));
 800            } else if !did_retry
 801                && response
 802                    .headers()
 803                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 804                    .is_some()
 805            {
 806                did_retry = true;
 807                token = llm_token.refresh(&client).await?;
 808            } else {
 809                let mut body = String::new();
 810                response.body_mut().read_to_string(&mut body).await?;
 811                anyhow::bail!(
 812                    "Request failed with status: {:?}\nBody: {}",
 813                    response.status(),
 814                    body
 815                );
 816            }
 817        }
 818    }
 819
 820    fn gather_nearby_diagnostics(
 821        cursor_offset: usize,
 822        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 823        snapshot: &BufferSnapshot,
 824        max_diagnostics_bytes: usize,
 825    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 826        // TODO: Could make this more efficient
 827        let mut diagnostic_groups = Vec::new();
 828        for (language_server_id, diagnostics) in diagnostic_sets {
 829            let mut groups = Vec::new();
 830            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 831            diagnostic_groups.extend(
 832                groups
 833                    .into_iter()
 834                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 835            );
 836        }
 837
 838        // sort by proximity to cursor
 839        diagnostic_groups.sort_by_key(|group| {
 840            let range = &group.entries[group.primary_ix].range;
 841            if range.start >= cursor_offset {
 842                range.start - cursor_offset
 843            } else if cursor_offset >= range.end {
 844                cursor_offset - range.end
 845            } else {
 846                (cursor_offset - range.start).min(range.end - cursor_offset)
 847            }
 848        });
 849
 850        let mut results = Vec::new();
 851        let mut diagnostic_groups_truncated = false;
 852        let mut diagnostics_byte_count = 0;
 853        for group in diagnostic_groups {
 854            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 855            diagnostics_byte_count += raw_value.get().len();
 856            if diagnostics_byte_count > max_diagnostics_bytes {
 857                diagnostic_groups_truncated = true;
 858                break;
 859            }
 860            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 861        }
 862
 863        (results, diagnostic_groups_truncated)
 864    }
 865
 866    // TODO: Dedupe with similar code in request_prediction?
 867    pub fn cloud_request_for_zeta_cli(
 868        &mut self,
 869        project: &Entity<Project>,
 870        buffer: &Entity<Buffer>,
 871        position: language::Anchor,
 872        cx: &mut Context<Self>,
 873    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 874        let project_state = self.projects.get(&project.entity_id());
 875
 876        let index_state = project_state.map(|state| {
 877            state
 878                .syntax_index
 879                .read_with(cx, |index, _cx| index.state().clone())
 880        });
 881        let options = self.options.clone();
 882        let snapshot = buffer.read(cx).snapshot();
 883        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 884            return Task::ready(Err(anyhow!("No file path for excerpt")));
 885        };
 886        let worktree_snapshots = project
 887            .read(cx)
 888            .worktrees(cx)
 889            .map(|worktree| worktree.read(cx).snapshot())
 890            .collect::<Vec<_>>();
 891
 892        let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| {
 893            let mut path = f.worktree.read(cx).absolutize(&f.path);
 894            if path.pop() { Some(path) } else { None }
 895        });
 896
 897        cx.background_spawn(async move {
 898            let index_state = if let Some(index_state) = index_state {
 899                Some(index_state.lock_owned().await)
 900            } else {
 901                None
 902            };
 903
 904            let cursor_point = position.to_point(&snapshot);
 905
 906            let debug_info = true;
 907            EditPredictionContext::gather_context(
 908                cursor_point,
 909                &snapshot,
 910                parent_abs_path.as_deref(),
 911                &options.context,
 912                index_state.as_deref(),
 913            )
 914            .context("Failed to select excerpt")
 915            .map(|context| {
 916                make_cloud_request(
 917                    excerpt_path.into(),
 918                    context,
 919                    // TODO pass everything
 920                    Vec::new(),
 921                    false,
 922                    Vec::new(),
 923                    false,
 924                    None,
 925                    debug_info,
 926                    &worktree_snapshots,
 927                    index_state.as_deref(),
 928                    Some(options.max_prompt_bytes),
 929                    options.prompt_format,
 930                )
 931            })
 932        })
 933    }
 934
 935    pub fn wait_for_initial_indexing(
 936        &mut self,
 937        project: &Entity<Project>,
 938        cx: &mut App,
 939    ) -> Task<Result<()>> {
 940        let zeta_project = self.get_or_init_zeta_project(project, cx);
 941        zeta_project
 942            .syntax_index
 943            .read(cx)
 944            .wait_for_initial_file_indexing(cx)
 945    }
 946}
 947
 948#[derive(Error, Debug)]
 949#[error(
 950    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 951)]
 952pub struct ZedUpdateRequiredError {
 953    minimum_version: SemanticVersion,
 954}
 955
 956fn make_cloud_request(
 957    excerpt_path: Arc<Path>,
 958    context: EditPredictionContext,
 959    events: Vec<predict_edits_v3::Event>,
 960    can_collect_data: bool,
 961    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 962    diagnostic_groups_truncated: bool,
 963    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 964    debug_info: bool,
 965    worktrees: &Vec<worktree::Snapshot>,
 966    index_state: Option<&SyntaxIndexState>,
 967    prompt_max_bytes: Option<usize>,
 968    prompt_format: PromptFormat,
 969) -> predict_edits_v3::PredictEditsRequest {
 970    let mut signatures = Vec::new();
 971    let mut declaration_to_signature_index = HashMap::default();
 972    let mut referenced_declarations = Vec::new();
 973
 974    for snippet in context.declarations {
 975        let project_entry_id = snippet.declaration.project_entry_id();
 976        let Some(path) = worktrees.iter().find_map(|worktree| {
 977            worktree.entry_for_id(project_entry_id).map(|entry| {
 978                let mut full_path = RelPathBuf::new();
 979                full_path.push(worktree.root_name());
 980                full_path.push(&entry.path);
 981                full_path
 982            })
 983        }) else {
 984            continue;
 985        };
 986
 987        let parent_index = index_state.and_then(|index_state| {
 988            snippet.declaration.parent().and_then(|parent| {
 989                add_signature(
 990                    parent,
 991                    &mut declaration_to_signature_index,
 992                    &mut signatures,
 993                    index_state,
 994                )
 995            })
 996        });
 997
 998        let (text, text_is_truncated) = snippet.declaration.item_text();
 999        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1000            path: path.as_std_path().into(),
1001            text: text.into(),
1002            range: snippet.declaration.item_line_range(),
1003            text_is_truncated,
1004            signature_range: snippet.declaration.signature_range_in_item_text(),
1005            parent_index,
1006            signature_score: snippet.score(DeclarationStyle::Signature),
1007            declaration_score: snippet.score(DeclarationStyle::Declaration),
1008            score_components: snippet.components,
1009        });
1010    }
1011
1012    let excerpt_parent = index_state.and_then(|index_state| {
1013        context
1014            .excerpt
1015            .parent_declarations
1016            .last()
1017            .and_then(|(parent, _)| {
1018                add_signature(
1019                    *parent,
1020                    &mut declaration_to_signature_index,
1021                    &mut signatures,
1022                    index_state,
1023                )
1024            })
1025    });
1026
1027    predict_edits_v3::PredictEditsRequest {
1028        excerpt_path,
1029        excerpt: context.excerpt_text.body,
1030        excerpt_line_range: context.excerpt.line_range,
1031        excerpt_range: context.excerpt.range,
1032        cursor_point: predict_edits_v3::Point {
1033            line: predict_edits_v3::Line(context.cursor_point.row),
1034            column: context.cursor_point.column,
1035        },
1036        referenced_declarations,
1037        signatures,
1038        excerpt_parent,
1039        events,
1040        can_collect_data,
1041        diagnostic_groups,
1042        diagnostic_groups_truncated,
1043        git_info,
1044        debug_info,
1045        prompt_max_bytes,
1046        prompt_format,
1047    }
1048}
1049
1050fn add_signature(
1051    declaration_id: DeclarationId,
1052    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1053    signatures: &mut Vec<Signature>,
1054    index: &SyntaxIndexState,
1055) -> Option<usize> {
1056    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1057        return Some(*signature_index);
1058    }
1059    let Some(parent_declaration) = index.declaration(declaration_id) else {
1060        log::error!("bug: missing parent declaration");
1061        return None;
1062    };
1063    let parent_index = parent_declaration.parent().and_then(|parent| {
1064        add_signature(parent, declaration_to_signature_index, signatures, index)
1065    });
1066    let (text, text_is_truncated) = parent_declaration.signature_text();
1067    let signature_index = signatures.len();
1068    signatures.push(Signature {
1069        text: text.into(),
1070        text_is_truncated,
1071        parent_index,
1072        range: parent_declaration.signature_line_range(),
1073    });
1074    declaration_to_signature_index.insert(declaration_id, signature_index);
1075    Some(signature_index)
1076}
1077
#[cfg(test)]
mod tests {
    //! Tests for `Zeta` request construction and prediction state. They drive a fake
    //! HTTP client (see `init_test`) that hands each prediction request to the test and
    //! waits for the test to supply the response.

    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3::{self, Point};
    use edit_prediction_context::Line;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks how `current_prediction_for_buffer` classifies predictions:
    /// - a prediction that edits the requesting buffer's own file is `Local`;
    /// - a prediction that edits another file is a `Jump` to that file;
    /// - once that other file is opened, its buffer sees the prediction as `Local`.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for the current file: the response edits `1.txt` itself.

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file: requested from `buffer1`, but the response edits
        // `2.txt`. The range reuses `snapshot1`'s row count; both files have three lines.
        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    range: Line(0)..Line(snapshot1.max_point().row + 1),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        // After opening `2.txt`, the same prediction is local to that buffer.
        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Checks the outgoing request's excerpt path and cursor point. Also checks that a
    /// response rewriting the whole file comes back as a single edit inserting
    /// `" are you?"` at row 1, column 3.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        assert_eq!(
            request.cursor_point,
            Point {
                line: Line(1),
                column: 3
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks the events sent with the next request. After the buffer is registered,
    /// `How` is inserted into its empty second line. The request must then carry
    /// exactly one `BufferChange` event whose unified diff records that edit, with
    /// `predicted: false`.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Offset 7 is the start of the (empty) second line.
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: Line(0)..Line(snapshot.max_point().row + 1),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an LSP diagnostic published for the file shows up in the request
    /// as a JSON-serialized group with resolved byte offsets.
    ///
    /// The diagnostic spans columns 1..5 of `Bye`, which starts at offset 7. It is
    /// expected to resolve to offsets 8..10; the end column lies past the end of
    /// the line.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        // The prediction is never answered; only the outgoing request is inspected.
        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Initializes settings, language, and project settings, then builds a client on a
    /// fake HTTP client (credentials `1` / `"test"`) and the global `Zeta`.
    ///
    /// The fake HTTP client routes requests by path:
    /// - `/client/llm_tokens` returns the token `"test"`.
    /// - `/predict_edits/v3` deserializes the request body and sends it, with a oneshot
    ///   responder, over the returned receiver. It then waits for the test to send the
    ///   response through that responder.
    /// - Any other path panics.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                // Hand the request to the test and block until it responds.
                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await?).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}