zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::mpsc;
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  21use language::{BufferSnapshot, TextBufferSnapshot};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::EditPrediction;
  39pub use provider::ZetaEditPredictionProvider;
  40
/// Maximum gap between two consecutive changes for them to be coalesced into a single
/// `Event::BufferChange` by `Zeta::push_event`.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  42
/// Maximum number of events to track per project. When the history reaches this size,
/// `Zeta::push_event` drops the oldest half before recording the next event.
const MAX_EVENT_COUNT: usize = 16;
  45
/// Excerpt options used by [`DEFAULT_OPTIONS`].
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
  51
/// Options a freshly created [`Zeta`] starts with (see `Zeta::new`).
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
};
  58
/// Wrapper that stores the shared [`Zeta`] entity as a gpui global
/// (installed by `Zeta::global`, read by `Zeta::try_global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  63
/// Edit prediction state: the client and token used to reach the prediction endpoint,
/// the request options, and per-project edit history and current predictions.
pub struct Zeta {
    /// Client used to acquire/refresh the LLM token and to send prediction requests.
    client: Arc<Client>,
    /// Source of the usage reported by `usage`; updated when a response carries usage data.
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential on prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the `RefreshLlmTokenListener` subscription alive; its callback refreshes `llm_token`.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options applied to subsequent requests.
    options: ZetaOptions,
    /// Set to `true` once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set (via `debug_info`), each request reports its debug info or error text here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  74
/// Tunable parameters for building prediction requests.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options passed to `EditPredictionContext::gather_context`.
    pub excerpt: EditPredictionExcerptOptions,
    /// Forwarded to the cloud request as its prompt byte limit.
    pub max_prompt_bytes: usize,
    /// Byte budget for the serialized diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format forwarded to the cloud request.
    pub prompt_format: predict_edits_v3::PromptFormat,
}
  82
/// Debug information about a single prediction request, delivered through the channel
/// returned by `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered for the request.
    pub context: EditPredictionContext,
    /// Wall-clock time spent gathering context, collecting diagnostics and building the request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Position the prediction was requested at.
    pub position: language::Anchor,
}
  90
/// Alias for the debug info type carried in prediction responses.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  92
/// State tracked for each registered project.
struct ZetaProject {
    /// Syntax index created for the project; its state is read when building requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer change events, oldest first (bounded by `MAX_EVENT_COUNT`).
    events: VecDeque<Event>,
    /// Buffers being watched for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently on offer for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
  99
/// A prediction together with the buffer that requested it.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested from.
    pub requested_by_buffer_id: EntityId,
    /// The prediction itself.
    pub prediction: EditPrediction,
}
 105
 106impl CurrentEditPrediction {
 107    fn should_replace_prediction(
 108        &self,
 109        old_prediction: &Self,
 110        snapshot: &TextBufferSnapshot,
 111    ) -> bool {
 112        if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
 113            return true;
 114        }
 115
 116        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 117            return true;
 118        };
 119
 120        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 121            return false;
 122        };
 123        if old_edits.len() == 1 && new_edits.len() == 1 {
 124            let (old_range, old_text) = &old_edits[0];
 125            let (new_range, new_text) = &new_edits[0];
 126            new_range == old_range && new_text.starts_with(old_text)
 127        } else {
 128            true
 129        }
 130    }
 131}
 132
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction does not target this buffer, but this buffer requested it.
    Jump { prediction: &'a EditPrediction },
}
 139
/// Bookkeeping for a buffer whose edits are being recorded.
struct RegisteredBuffer {
    /// Snapshot taken at registration or at the most recently reported change.
    snapshot: BufferSnapshot,
    /// Subscriptions for buffer events and buffer release; held so they stay active.
    _subscriptions: [gpui::Subscription; 2],
}
 144
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`.
    BufferChange {
        /// Buffer state before the change.
        old_snapshot: BufferSnapshot,
        /// Buffer state after the change (advanced when later changes are coalesced).
        new_snapshot: BufferSnapshot,
        /// When the (latest coalesced) change was recorded.
        timestamp: Instant,
    },
}
 153
 154impl Zeta {
 155    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 156        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 157    }
 158
 159    pub fn global(
 160        client: &Arc<Client>,
 161        user_store: &Entity<UserStore>,
 162        cx: &mut App,
 163    ) -> Entity<Self> {
 164        cx.try_global::<ZetaGlobal>()
 165            .map(|global| global.0.clone())
 166            .unwrap_or_else(|| {
 167                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 168                cx.set_global(ZetaGlobal(zeta.clone()));
 169                zeta
 170            })
 171    }
 172
 173    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 174        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 175
 176        Self {
 177            projects: HashMap::new(),
 178            client,
 179            user_store,
 180            options: DEFAULT_OPTIONS,
 181            llm_token: LlmApiToken::default(),
 182            _llm_token_subscription: cx.subscribe(
 183                &refresh_llm_token_listener,
 184                |this, _listener, _event, cx| {
 185                    let client = this.client.clone();
 186                    let llm_token = this.llm_token.clone();
 187                    cx.spawn(async move |_this, _cx| {
 188                        llm_token.refresh(&client).await?;
 189                        anyhow::Ok(())
 190                    })
 191                    .detach_and_log_err(cx);
 192                },
 193            ),
 194            update_required: false,
 195            debug_tx: None,
 196        }
 197    }
 198
 199    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 200        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 201        self.debug_tx = Some(debug_watch_tx);
 202        debug_watch_rx
 203    }
 204
    /// Returns the options currently used for prediction requests.
    pub fn options(&self) -> &ZetaOptions {
        &self.options
    }
 208
    /// Replaces the options used for subsequent prediction requests.
    pub fn set_options(&mut self, options: ZetaOptions) {
        self.options = options;
    }
 212
 213    pub fn clear_history(&mut self) {
 214        for zeta_project in self.projects.values_mut() {
 215            zeta_project.events.clear();
 216        }
 217    }
 218
 219    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 220        self.user_store.read(cx).edit_prediction_usage()
 221    }
 222
 223    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 224        self.get_or_init_zeta_project(project, cx);
 225    }
 226
 227    pub fn register_buffer(
 228        &mut self,
 229        buffer: &Entity<Buffer>,
 230        project: &Entity<Project>,
 231        cx: &mut Context<Self>,
 232    ) {
 233        let zeta_project = self.get_or_init_zeta_project(project, cx);
 234        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 235    }
 236
 237    fn get_or_init_zeta_project(
 238        &mut self,
 239        project: &Entity<Project>,
 240        cx: &mut App,
 241    ) -> &mut ZetaProject {
 242        self.projects
 243            .entry(project.entity_id())
 244            .or_insert_with(|| ZetaProject {
 245                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 246                events: VecDeque::new(),
 247                registered_buffers: HashMap::new(),
 248                current_prediction: None,
 249            })
 250    }
 251
 252    fn register_buffer_impl<'a>(
 253        zeta_project: &'a mut ZetaProject,
 254        buffer: &Entity<Buffer>,
 255        project: &Entity<Project>,
 256        cx: &mut Context<Self>,
 257    ) -> &'a mut RegisteredBuffer {
 258        let buffer_id = buffer.entity_id();
 259        match zeta_project.registered_buffers.entry(buffer_id) {
 260            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 261            hash_map::Entry::Vacant(entry) => {
 262                let snapshot = buffer.read(cx).snapshot();
 263                let project_entity_id = project.entity_id();
 264                entry.insert(RegisteredBuffer {
 265                    snapshot,
 266                    _subscriptions: [
 267                        cx.subscribe(buffer, {
 268                            let project = project.downgrade();
 269                            move |this, buffer, event, cx| {
 270                                if let language::BufferEvent::Edited = event
 271                                    && let Some(project) = project.upgrade()
 272                                {
 273                                    this.report_changes_for_buffer(&buffer, &project, cx);
 274                                }
 275                            }
 276                        }),
 277                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 278                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 279                            else {
 280                                return;
 281                            };
 282                            zeta_project.registered_buffers.remove(&buffer_id);
 283                        }),
 284                    ],
 285                })
 286            }
 287        }
 288    }
 289
 290    fn report_changes_for_buffer(
 291        &mut self,
 292        buffer: &Entity<Buffer>,
 293        project: &Entity<Project>,
 294        cx: &mut Context<Self>,
 295    ) -> BufferSnapshot {
 296        let zeta_project = self.get_or_init_zeta_project(project, cx);
 297        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 298
 299        let new_snapshot = buffer.read(cx).snapshot();
 300        if new_snapshot.version != registered_buffer.snapshot.version {
 301            let old_snapshot =
 302                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 303            Self::push_event(
 304                zeta_project,
 305                Event::BufferChange {
 306                    old_snapshot,
 307                    new_snapshot: new_snapshot.clone(),
 308                    timestamp: Instant::now(),
 309                },
 310            );
 311        }
 312
 313        new_snapshot
 314    }
 315
 316    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 317        let events = &mut zeta_project.events;
 318
 319        if let Some(Event::BufferChange {
 320            new_snapshot: last_new_snapshot,
 321            timestamp: last_timestamp,
 322            ..
 323        }) = events.back_mut()
 324        {
 325            // Coalesce edits for the same buffer when they happen one after the other.
 326            let Event::BufferChange {
 327                old_snapshot,
 328                new_snapshot,
 329                timestamp,
 330            } = &event;
 331
 332            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 333                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 334                && old_snapshot.version == last_new_snapshot.version
 335            {
 336                *last_new_snapshot = new_snapshot.clone();
 337                *last_timestamp = *timestamp;
 338                return;
 339            }
 340        }
 341
 342        if events.len() >= MAX_EVENT_COUNT {
 343            // These are halved instead of popping to improve prompt caching.
 344            events.drain(..MAX_EVENT_COUNT / 2);
 345        }
 346
 347        events.push_back(event);
 348    }
 349
 350    fn current_prediction_for_buffer(
 351        &self,
 352        buffer: &Entity<Buffer>,
 353        project: &Entity<Project>,
 354        cx: &App,
 355    ) -> Option<BufferEditPrediction<'_>> {
 356        let project_state = self.projects.get(&project.entity_id())?;
 357
 358        let CurrentEditPrediction {
 359            requested_by_buffer_id,
 360            prediction,
 361        } = project_state.current_prediction.as_ref()?;
 362
 363        if prediction.targets_buffer(buffer.read(cx), cx) {
 364            Some(BufferEditPrediction::Local { prediction })
 365        } else if *requested_by_buffer_id == buffer.entity_id() {
 366            Some(BufferEditPrediction::Jump { prediction })
 367        } else {
 368            None
 369        }
 370    }
 371
 372    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
 373        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 374            project_state.current_prediction.take();
 375        };
 376        // TODO report accepted
 377    }
 378
 379    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 380        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 381            project_state.current_prediction.take();
 382        };
 383    }
 384
 385    pub fn refresh_prediction(
 386        &mut self,
 387        project: &Entity<Project>,
 388        buffer: &Entity<Buffer>,
 389        position: language::Anchor,
 390        cx: &mut Context<Self>,
 391    ) -> Task<Result<()>> {
 392        let request_task = self.request_prediction(project, buffer, position, cx);
 393        let buffer = buffer.clone();
 394        let project = project.clone();
 395
 396        cx.spawn(async move |this, cx| {
 397            if let Some(prediction) = request_task.await? {
 398                this.update(cx, |this, cx| {
 399                    let project_state = this
 400                        .projects
 401                        .get_mut(&project.entity_id())
 402                        .context("Project not found")?;
 403
 404                    let new_prediction = CurrentEditPrediction {
 405                        requested_by_buffer_id: buffer.entity_id(),
 406                        prediction: prediction,
 407                    };
 408
 409                    if project_state
 410                        .current_prediction
 411                        .as_ref()
 412                        .is_none_or(|old_prediction| {
 413                            new_prediction
 414                                .should_replace_prediction(&old_prediction, buffer.read(cx))
 415                        })
 416                    {
 417                        project_state.current_prediction = Some(new_prediction);
 418                    }
 419                    anyhow::Ok(())
 420                })??;
 421            }
 422            Ok(())
 423        })
 424    }
 425
 426    fn request_prediction(
 427        &mut self,
 428        project: &Entity<Project>,
 429        buffer: &Entity<Buffer>,
 430        position: language::Anchor,
 431        cx: &mut Context<Self>,
 432    ) -> Task<Result<Option<EditPrediction>>> {
 433        let project_state = self.projects.get(&project.entity_id());
 434
 435        let index_state = project_state.map(|state| {
 436            state
 437                .syntax_index
 438                .read_with(cx, |index, _cx| index.state().clone())
 439        });
 440        let options = self.options.clone();
 441        let snapshot = buffer.read(cx).snapshot();
 442        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 443            return Task::ready(Err(anyhow!("No file path for excerpt")));
 444        };
 445        let client = self.client.clone();
 446        let llm_token = self.llm_token.clone();
 447        let app_version = AppVersion::global(cx);
 448        let worktree_snapshots = project
 449            .read(cx)
 450            .worktrees(cx)
 451            .map(|worktree| worktree.read(cx).snapshot())
 452            .collect::<Vec<_>>();
 453        let debug_tx = self.debug_tx.clone();
 454
 455        let events = project_state
 456            .map(|state| {
 457                state
 458                    .events
 459                    .iter()
 460                    .filter_map(|event| match event {
 461                        Event::BufferChange {
 462                            old_snapshot,
 463                            new_snapshot,
 464                            ..
 465                        } => {
 466                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 467
 468                            let old_path = old_snapshot.file().and_then(|f| {
 469                                let old_path = f.full_path(cx);
 470                                if Some(&old_path) != path.as_ref() {
 471                                    Some(old_path)
 472                                } else {
 473                                    None
 474                                }
 475                            });
 476
 477                            // TODO [zeta2] move to bg?
 478                            let diff =
 479                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 480
 481                            if path == old_path && diff.is_empty() {
 482                                None
 483                            } else {
 484                                Some(predict_edits_v3::Event::BufferChange {
 485                                    old_path,
 486                                    path,
 487                                    diff,
 488                                    //todo: Actually detect if this edit was predicted or not
 489                                    predicted: false,
 490                                })
 491                            }
 492                        }
 493                    })
 494                    .collect::<Vec<_>>()
 495            })
 496            .unwrap_or_default();
 497
 498        let diagnostics = snapshot.diagnostic_sets().clone();
 499
 500        let request_task = cx.background_spawn({
 501            let snapshot = snapshot.clone();
 502            let buffer = buffer.clone();
 503            async move {
 504                let index_state = if let Some(index_state) = index_state {
 505                    Some(index_state.lock_owned().await)
 506                } else {
 507                    None
 508                };
 509
 510                let cursor_offset = position.to_offset(&snapshot);
 511                let cursor_point = cursor_offset.to_point(&snapshot);
 512
 513                let before_retrieval = chrono::Utc::now();
 514
 515                let Some(context) = EditPredictionContext::gather_context(
 516                    cursor_point,
 517                    &snapshot,
 518                    &options.excerpt,
 519                    index_state.as_deref(),
 520                ) else {
 521                    return Ok(None);
 522                };
 523
 524                let debug_context = if let Some(debug_tx) = debug_tx {
 525                    Some((debug_tx, context.clone()))
 526                } else {
 527                    None
 528                };
 529
 530                let (diagnostic_groups, diagnostic_groups_truncated) =
 531                    Self::gather_nearby_diagnostics(
 532                        cursor_offset,
 533                        &diagnostics,
 534                        &snapshot,
 535                        options.max_diagnostic_bytes,
 536                    );
 537
 538                let request = make_cloud_request(
 539                    excerpt_path,
 540                    context,
 541                    events,
 542                    // TODO data collection
 543                    false,
 544                    diagnostic_groups,
 545                    diagnostic_groups_truncated,
 546                    None,
 547                    debug_context.is_some(),
 548                    &worktree_snapshots,
 549                    index_state.as_deref(),
 550                    Some(options.max_prompt_bytes),
 551                    options.prompt_format,
 552                );
 553
 554                let retrieval_time = chrono::Utc::now() - before_retrieval;
 555                let response = Self::perform_request(client, llm_token, app_version, request).await;
 556
 557                if let Some((debug_tx, context)) = debug_context {
 558                    debug_tx
 559                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 560                            |response| {
 561                                let Some(request) =
 562                                    some_or_debug_panic(response.0.debug_info.clone())
 563                                else {
 564                                    return Err("Missing debug info".to_string());
 565                                };
 566                                Ok(PredictionDebugInfo {
 567                                    context,
 568                                    request,
 569                                    retrieval_time,
 570                                    buffer: buffer.downgrade(),
 571                                    position,
 572                                })
 573                            },
 574                        ))
 575                        .ok();
 576                }
 577
 578                anyhow::Ok(Some(response?))
 579            }
 580        });
 581
 582        let buffer = buffer.clone();
 583
 584        cx.spawn({
 585            let project = project.clone();
 586            async move |this, cx| {
 587                match request_task.await {
 588                    Ok(Some((response, usage))) => {
 589                        if let Some(usage) = usage {
 590                            this.update(cx, |this, cx| {
 591                                this.user_store.update(cx, |user_store, cx| {
 592                                    user_store.update_edit_prediction_usage(usage, cx);
 593                                });
 594                            })
 595                            .ok();
 596                        }
 597
 598                        let prediction = EditPrediction::from_response(
 599                            response, &snapshot, &buffer, &project, cx,
 600                        )
 601                        .await;
 602
 603                        // TODO telemetry: duration, etc
 604                        Ok(prediction)
 605                    }
 606                    Ok(None) => Ok(None),
 607                    Err(err) => {
 608                        if err.is::<ZedUpdateRequiredError>() {
 609                            cx.update(|cx| {
 610                                this.update(cx, |this, _cx| {
 611                                    this.update_required = true;
 612                                })
 613                                .ok();
 614
 615                                let error_message: SharedString = err.to_string().into();
 616                                show_app_notification(
 617                                    NotificationId::unique::<ZedUpdateRequiredError>(),
 618                                    cx,
 619                                    move |cx| {
 620                                        cx.new(|cx| {
 621                                            ErrorMessagePrompt::new(error_message.clone(), cx)
 622                                                .with_link_button(
 623                                                    "Update Zed",
 624                                                    "https://zed.dev/releases",
 625                                                )
 626                                        })
 627                                    },
 628                                );
 629                            })
 630                            .ok();
 631                        }
 632
 633                        Err(err)
 634                    }
 635                }
 636            }
 637        })
 638    }
 639
 640    async fn perform_request(
 641        client: Arc<Client>,
 642        llm_token: LlmApiToken,
 643        app_version: SemanticVersion,
 644        request: predict_edits_v3::PredictEditsRequest,
 645    ) -> Result<(
 646        predict_edits_v3::PredictEditsResponse,
 647        Option<EditPredictionUsage>,
 648    )> {
 649        let http_client = client.http_client();
 650        let mut token = llm_token.acquire(&client).await?;
 651        let mut did_retry = false;
 652
 653        loop {
 654            let request_builder = http_client::Request::builder().method(Method::POST);
 655            let request_builder =
 656                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 657                    request_builder.uri(predict_edits_url)
 658                } else {
 659                    request_builder.uri(
 660                        http_client
 661                            .build_zed_llm_url("/predict_edits/v3", &[])?
 662                            .as_ref(),
 663                    )
 664                };
 665            let request = request_builder
 666                .header("Content-Type", "application/json")
 667                .header("Authorization", format!("Bearer {}", token))
 668                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 669                .body(serde_json::to_string(&request)?.into())?;
 670
 671            let mut response = http_client.send(request).await?;
 672
 673            if let Some(minimum_required_version) = response
 674                .headers()
 675                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 676                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 677            {
 678                anyhow::ensure!(
 679                    app_version >= minimum_required_version,
 680                    ZedUpdateRequiredError {
 681                        minimum_version: minimum_required_version
 682                    }
 683                );
 684            }
 685
 686            if response.status().is_success() {
 687                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 688
 689                let mut body = Vec::new();
 690                response.body_mut().read_to_end(&mut body).await?;
 691                return Ok((serde_json::from_slice(&body)?, usage));
 692            } else if !did_retry
 693                && response
 694                    .headers()
 695                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 696                    .is_some()
 697            {
 698                did_retry = true;
 699                token = llm_token.refresh(&client).await?;
 700            } else {
 701                let mut body = String::new();
 702                response.body_mut().read_to_string(&mut body).await?;
 703                anyhow::bail!(
 704                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 705                    response.status(),
 706                    body
 707                );
 708            }
 709        }
 710    }
 711
 712    fn gather_nearby_diagnostics(
 713        cursor_offset: usize,
 714        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 715        snapshot: &BufferSnapshot,
 716        max_diagnostics_bytes: usize,
 717    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 718        // TODO: Could make this more efficient
 719        let mut diagnostic_groups = Vec::new();
 720        for (language_server_id, diagnostics) in diagnostic_sets {
 721            let mut groups = Vec::new();
 722            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 723            diagnostic_groups.extend(
 724                groups
 725                    .into_iter()
 726                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 727            );
 728        }
 729
 730        // sort by proximity to cursor
 731        diagnostic_groups.sort_by_key(|group| {
 732            let range = &group.entries[group.primary_ix].range;
 733            if range.start >= cursor_offset {
 734                range.start - cursor_offset
 735            } else if cursor_offset >= range.end {
 736                cursor_offset - range.end
 737            } else {
 738                (cursor_offset - range.start).min(range.end - cursor_offset)
 739            }
 740        });
 741
 742        let mut results = Vec::new();
 743        let mut diagnostic_groups_truncated = false;
 744        let mut diagnostics_byte_count = 0;
 745        for group in diagnostic_groups {
 746            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 747            diagnostics_byte_count += raw_value.get().len();
 748            if diagnostics_byte_count > max_diagnostics_bytes {
 749                diagnostic_groups_truncated = true;
 750                break;
 751            }
 752            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 753        }
 754
 755        (results, diagnostic_groups_truncated)
 756    }
 757
 758    // TODO: Dedupe with similar code in request_prediction?
 759    pub fn cloud_request_for_zeta_cli(
 760        &mut self,
 761        project: &Entity<Project>,
 762        buffer: &Entity<Buffer>,
 763        position: language::Anchor,
 764        cx: &mut Context<Self>,
 765    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 766        let project_state = self.projects.get(&project.entity_id());
 767
 768        let index_state = project_state.map(|state| {
 769            state
 770                .syntax_index
 771                .read_with(cx, |index, _cx| index.state().clone())
 772        });
 773        let options = self.options.clone();
 774        let snapshot = buffer.read(cx).snapshot();
 775        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 776            return Task::ready(Err(anyhow!("No file path for excerpt")));
 777        };
 778        let worktree_snapshots = project
 779            .read(cx)
 780            .worktrees(cx)
 781            .map(|worktree| worktree.read(cx).snapshot())
 782            .collect::<Vec<_>>();
 783
 784        cx.background_spawn(async move {
 785            let index_state = if let Some(index_state) = index_state {
 786                Some(index_state.lock_owned().await)
 787            } else {
 788                None
 789            };
 790
 791            let cursor_point = position.to_point(&snapshot);
 792
 793            let debug_info = true;
 794            EditPredictionContext::gather_context(
 795                cursor_point,
 796                &snapshot,
 797                &options.excerpt,
 798                index_state.as_deref(),
 799            )
 800            .context("Failed to select excerpt")
 801            .map(|context| {
 802                make_cloud_request(
 803                    excerpt_path.into(),
 804                    context,
 805                    // TODO pass everything
 806                    Vec::new(),
 807                    false,
 808                    Vec::new(),
 809                    false,
 810                    None,
 811                    debug_info,
 812                    &worktree_snapshots,
 813                    index_state.as_deref(),
 814                    Some(options.max_prompt_bytes),
 815                    options.prompt_format,
 816                )
 817            })
 818        })
 819    }
 820}
 821
 822#[derive(Error, Debug)]
 823#[error(
 824    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 825)]
 826pub struct ZedUpdateRequiredError {
 827    minimum_version: SemanticVersion,
 828}
 829
 830fn make_cloud_request(
 831    excerpt_path: Arc<Path>,
 832    context: EditPredictionContext,
 833    events: Vec<predict_edits_v3::Event>,
 834    can_collect_data: bool,
 835    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 836    diagnostic_groups_truncated: bool,
 837    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 838    debug_info: bool,
 839    worktrees: &Vec<worktree::Snapshot>,
 840    index_state: Option<&SyntaxIndexState>,
 841    prompt_max_bytes: Option<usize>,
 842    prompt_format: PromptFormat,
 843) -> predict_edits_v3::PredictEditsRequest {
 844    let mut signatures = Vec::new();
 845    let mut declaration_to_signature_index = HashMap::default();
 846    let mut referenced_declarations = Vec::new();
 847
 848    for snippet in context.declarations {
 849        let project_entry_id = snippet.declaration.project_entry_id();
 850        let Some(path) = worktrees.iter().find_map(|worktree| {
 851            worktree.entry_for_id(project_entry_id).map(|entry| {
 852                let mut full_path = RelPathBuf::new();
 853                full_path.push(worktree.root_name());
 854                full_path.push(&entry.path);
 855                full_path
 856            })
 857        }) else {
 858            continue;
 859        };
 860
 861        let parent_index = index_state.and_then(|index_state| {
 862            snippet.declaration.parent().and_then(|parent| {
 863                add_signature(
 864                    parent,
 865                    &mut declaration_to_signature_index,
 866                    &mut signatures,
 867                    index_state,
 868                )
 869            })
 870        });
 871
 872        let (text, text_is_truncated) = snippet.declaration.item_text();
 873        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 874            path: path.as_std_path().into(),
 875            text: text.into(),
 876            range: snippet.declaration.item_range(),
 877            text_is_truncated,
 878            signature_range: snippet.declaration.signature_range_in_item_text(),
 879            parent_index,
 880            score_components: snippet.score_components,
 881            signature_score: snippet.scores.signature,
 882            declaration_score: snippet.scores.declaration,
 883        });
 884    }
 885
 886    let excerpt_parent = index_state.and_then(|index_state| {
 887        context
 888            .excerpt
 889            .parent_declarations
 890            .last()
 891            .and_then(|(parent, _)| {
 892                add_signature(
 893                    *parent,
 894                    &mut declaration_to_signature_index,
 895                    &mut signatures,
 896                    index_state,
 897                )
 898            })
 899    });
 900
 901    predict_edits_v3::PredictEditsRequest {
 902        excerpt_path,
 903        excerpt: context.excerpt_text.body,
 904        excerpt_range: context.excerpt.range,
 905        cursor_offset: context.cursor_offset_in_excerpt,
 906        referenced_declarations,
 907        signatures,
 908        excerpt_parent,
 909        events,
 910        can_collect_data,
 911        diagnostic_groups,
 912        diagnostic_groups_truncated,
 913        git_info,
 914        debug_info,
 915        prompt_max_bytes,
 916        prompt_format,
 917    }
 918}
 919
 920fn add_signature(
 921    declaration_id: DeclarationId,
 922    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 923    signatures: &mut Vec<Signature>,
 924    index: &SyntaxIndexState,
 925) -> Option<usize> {
 926    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 927        return Some(*signature_index);
 928    }
 929    let Some(parent_declaration) = index.declaration(declaration_id) else {
 930        log::error!("bug: missing parent declaration");
 931        return None;
 932    };
 933    let parent_index = parent_declaration.parent().and_then(|parent| {
 934        add_signature(parent, declaration_to_signature_index, signatures, index)
 935    });
 936    let (text, text_is_truncated) = parent_declaration.signature_text();
 937    let signature_index = signatures.len();
 938    signatures.push(Signature {
 939        text: text.into(),
 940        text_is_truncated,
 941        parent_index,
 942        range: parent_declaration.signature_range(),
 943    });
 944    declaration_to_signature_index.insert(declaration_id, signature_index);
 945    Some(signature_index)
 946}
 947
 948#[cfg(test)]
 949mod tests {
 950    use std::{
 951        path::{Path, PathBuf},
 952        sync::Arc,
 953    };
 954
 955    use client::UserStore;
 956    use clock::FakeSystemClock;
 957    use cloud_llm_client::predict_edits_v3;
 958    use futures::{
 959        AsyncReadExt, StreamExt,
 960        channel::{mpsc, oneshot},
 961    };
 962    use gpui::{
 963        Entity, TestAppContext,
 964        http_client::{FakeHttpClient, Response},
 965        prelude::*,
 966    };
 967    use indoc::indoc;
 968    use language::{LanguageServerId, OffsetRangeExt as _};
 969    use pretty_assertions::{assert_eq, assert_matches};
 970    use project::{FakeFs, Project};
 971    use serde_json::json;
 972    use settings::SettingsStore;
 973    use util::path;
 974    use uuid::Uuid;
 975
 976    use crate::{BufferEditPrediction, Zeta};
 977
 978    #[gpui::test]
 979    async fn test_current_state(cx: &mut TestAppContext) {
 980        let (zeta, mut req_rx) = init_test(cx);
 981        let fs = FakeFs::new(cx.executor());
 982        fs.insert_tree(
 983            "/root",
 984            json!({
 985                "1.txt": "Hello!\nHow\nBye",
 986                "2.txt": "Hola!\nComo\nAdios"
 987            }),
 988        )
 989        .await;
 990        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
 991
 992        zeta.update(cx, |zeta, cx| {
 993            zeta.register_project(&project, cx);
 994        });
 995
 996        let buffer1 = project
 997            .update(cx, |project, cx| {
 998                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
 999                project.open_buffer(path, cx)
1000            })
1001            .await
1002            .unwrap();
1003        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
1004        let position = snapshot1.anchor_before(language::Point::new(1, 3));
1005
1006        // Prediction for current file
1007
1008        let prediction_task = zeta.update(cx, |zeta, cx| {
1009            zeta.refresh_prediction(&project, &buffer1, position, cx)
1010        });
1011        let (_request, respond_tx) = req_rx.next().await.unwrap();
1012        respond_tx
1013            .send(predict_edits_v3::PredictEditsResponse {
1014                request_id: Uuid::new_v4(),
1015                edits: vec![predict_edits_v3::Edit {
1016                    path: Path::new(path!("root/1.txt")).into(),
1017                    range: 0..snapshot1.len(),
1018                    content: "Hello!\nHow are you?\nBye".into(),
1019                }],
1020                debug_info: None,
1021            })
1022            .unwrap();
1023        prediction_task.await.unwrap();
1024
1025        zeta.read_with(cx, |zeta, cx| {
1026            let prediction = zeta
1027                .current_prediction_for_buffer(&buffer1, &project, cx)
1028                .unwrap();
1029            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1030        });
1031
1032        // Prediction for another file
1033
1034        let prediction_task = zeta.update(cx, |zeta, cx| {
1035            zeta.refresh_prediction(&project, &buffer1, position, cx)
1036        });
1037        let (_request, respond_tx) = req_rx.next().await.unwrap();
1038        respond_tx
1039            .send(predict_edits_v3::PredictEditsResponse {
1040                request_id: Uuid::new_v4(),
1041                edits: vec![predict_edits_v3::Edit {
1042                    path: Path::new(path!("root/2.txt")).into(),
1043                    range: 0..snapshot1.len(),
1044                    content: "Hola!\nComo estas?\nAdios".into(),
1045                }],
1046                debug_info: None,
1047            })
1048            .unwrap();
1049        prediction_task.await.unwrap();
1050
1051        zeta.read_with(cx, |zeta, cx| {
1052            let prediction = zeta
1053                .current_prediction_for_buffer(&buffer1, &project, cx)
1054                .unwrap();
1055            assert_matches!(
1056                prediction,
1057                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
1058            );
1059        });
1060
1061        let buffer2 = project
1062            .update(cx, |project, cx| {
1063                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
1064                project.open_buffer(path, cx)
1065            })
1066            .await
1067            .unwrap();
1068
1069        zeta.read_with(cx, |zeta, cx| {
1070            let prediction = zeta
1071                .current_prediction_for_buffer(&buffer2, &project, cx)
1072                .unwrap();
1073            assert_matches!(prediction, BufferEditPrediction::Local { .. });
1074        });
1075    }
1076
1077    #[gpui::test]
1078    async fn test_simple_request(cx: &mut TestAppContext) {
1079        let (zeta, mut req_rx) = init_test(cx);
1080        let fs = FakeFs::new(cx.executor());
1081        fs.insert_tree(
1082            "/root",
1083            json!({
1084                "foo.md":  "Hello!\nHow\nBye"
1085            }),
1086        )
1087        .await;
1088        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1089
1090        let buffer = project
1091            .update(cx, |project, cx| {
1092                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1093                project.open_buffer(path, cx)
1094            })
1095            .await
1096            .unwrap();
1097        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1098        let position = snapshot.anchor_before(language::Point::new(1, 3));
1099
1100        let prediction_task = zeta.update(cx, |zeta, cx| {
1101            zeta.request_prediction(&project, &buffer, position, cx)
1102        });
1103
1104        let (request, respond_tx) = req_rx.next().await.unwrap();
1105        assert_eq!(
1106            request.excerpt_path.as_ref(),
1107            Path::new(path!("root/foo.md"))
1108        );
1109        assert_eq!(request.cursor_offset, 10);
1110
1111        respond_tx
1112            .send(predict_edits_v3::PredictEditsResponse {
1113                request_id: Uuid::new_v4(),
1114                edits: vec![predict_edits_v3::Edit {
1115                    path: Path::new(path!("root/foo.md")).into(),
1116                    range: 0..snapshot.len(),
1117                    content: "Hello!\nHow are you?\nBye".into(),
1118                }],
1119                debug_info: None,
1120            })
1121            .unwrap();
1122
1123        let prediction = prediction_task.await.unwrap().unwrap();
1124
1125        assert_eq!(prediction.edits.len(), 1);
1126        assert_eq!(
1127            prediction.edits[0].0.to_point(&snapshot).start,
1128            language::Point::new(1, 3)
1129        );
1130        assert_eq!(prediction.edits[0].1, " are you?");
1131    }
1132
1133    #[gpui::test]
1134    async fn test_request_events(cx: &mut TestAppContext) {
1135        let (zeta, mut req_rx) = init_test(cx);
1136        let fs = FakeFs::new(cx.executor());
1137        fs.insert_tree(
1138            "/root",
1139            json!({
1140                "foo.md": "Hello!\n\nBye"
1141            }),
1142        )
1143        .await;
1144        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1145
1146        let buffer = project
1147            .update(cx, |project, cx| {
1148                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1149                project.open_buffer(path, cx)
1150            })
1151            .await
1152            .unwrap();
1153
1154        zeta.update(cx, |zeta, cx| {
1155            zeta.register_buffer(&buffer, &project, cx);
1156        });
1157
1158        buffer.update(cx, |buffer, cx| {
1159            buffer.edit(vec![(7..7, "How")], None, cx);
1160        });
1161
1162        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1163        let position = snapshot.anchor_before(language::Point::new(1, 3));
1164
1165        let prediction_task = zeta.update(cx, |zeta, cx| {
1166            zeta.request_prediction(&project, &buffer, position, cx)
1167        });
1168
1169        let (request, respond_tx) = req_rx.next().await.unwrap();
1170
1171        assert_eq!(request.events.len(), 1);
1172        assert_eq!(
1173            request.events[0],
1174            predict_edits_v3::Event::BufferChange {
1175                path: Some(PathBuf::from(path!("root/foo.md"))),
1176                old_path: None,
1177                diff: indoc! {"
1178                        @@ -1,3 +1,3 @@
1179                         Hello!
1180                        -
1181                        +How
1182                         Bye
1183                    "}
1184                .to_string(),
1185                predicted: false
1186            }
1187        );
1188
1189        respond_tx
1190            .send(predict_edits_v3::PredictEditsResponse {
1191                request_id: Uuid::new_v4(),
1192                edits: vec![predict_edits_v3::Edit {
1193                    path: Path::new(path!("root/foo.md")).into(),
1194                    range: 0..snapshot.len(),
1195                    content: "Hello!\nHow are you?\nBye".into(),
1196                }],
1197                debug_info: None,
1198            })
1199            .unwrap();
1200
1201        let prediction = prediction_task.await.unwrap().unwrap();
1202
1203        assert_eq!(prediction.edits.len(), 1);
1204        assert_eq!(
1205            prediction.edits[0].0.to_point(&snapshot).start,
1206            language::Point::new(1, 3)
1207        );
1208        assert_eq!(prediction.edits[0].1, " are you?");
1209    }
1210
1211    #[gpui::test]
1212    async fn test_request_diagnostics(cx: &mut TestAppContext) {
1213        let (zeta, mut req_rx) = init_test(cx);
1214        let fs = FakeFs::new(cx.executor());
1215        fs.insert_tree(
1216            "/root",
1217            json!({
1218                "foo.md": "Hello!\nBye"
1219            }),
1220        )
1221        .await;
1222        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
1223
1224        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
1225        let diagnostic = lsp::Diagnostic {
1226            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
1227            severity: Some(lsp::DiagnosticSeverity::ERROR),
1228            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
1229            ..Default::default()
1230        };
1231
1232        project.update(cx, |project, cx| {
1233            project.lsp_store().update(cx, |lsp_store, cx| {
1234                // Create some diagnostics
1235                lsp_store
1236                    .update_diagnostics(
1237                        LanguageServerId(0),
1238                        lsp::PublishDiagnosticsParams {
1239                            uri: path_to_buffer_uri.clone(),
1240                            diagnostics: vec![diagnostic],
1241                            version: None,
1242                        },
1243                        None,
1244                        language::DiagnosticSourceKind::Pushed,
1245                        &[],
1246                        cx,
1247                    )
1248                    .unwrap();
1249            });
1250        });
1251
1252        let buffer = project
1253            .update(cx, |project, cx| {
1254                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
1255                project.open_buffer(path, cx)
1256            })
1257            .await
1258            .unwrap();
1259
1260        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
1261        let position = snapshot.anchor_before(language::Point::new(0, 0));
1262
1263        let _prediction_task = zeta.update(cx, |zeta, cx| {
1264            zeta.request_prediction(&project, &buffer, position, cx)
1265        });
1266
1267        let (request, _respond_tx) = req_rx.next().await.unwrap();
1268
1269        assert_eq!(request.diagnostic_groups.len(), 1);
1270        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
1271            .unwrap();
1272        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
1273        assert_eq!(
1274            value,
1275            json!({
1276                "entries": [{
1277                    "range": {
1278                        "start": 8,
1279                        "end": 10
1280                    },
1281                    "diagnostic": {
1282                        "source": null,
1283                        "code": null,
1284                        "code_description": null,
1285                        "severity": 1,
1286                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
1287                        "markdown": null,
1288                        "group_id": 0,
1289                        "is_primary": true,
1290                        "is_disk_based": false,
1291                        "is_unnecessary": false,
1292                        "source_kind": "Pushed",
1293                        "data": null,
1294                        "underline": true
1295                    }
1296                }],
1297                "primary_ix": 0
1298            })
1299        );
1300    }
1301
1302    fn init_test(
1303        cx: &mut TestAppContext,
1304    ) -> (
1305        Entity<Zeta>,
1306        mpsc::UnboundedReceiver<(
1307            predict_edits_v3::PredictEditsRequest,
1308            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
1309        )>,
1310    ) {
1311        cx.update(move |cx| {
1312            let settings_store = SettingsStore::test(cx);
1313            cx.set_global(settings_store);
1314            language::init(cx);
1315            Project::init_settings(cx);
1316
1317            let (req_tx, req_rx) = mpsc::unbounded();
1318
1319            let http_client = FakeHttpClient::create({
1320                move |req| {
1321                    let uri = req.uri().path().to_string();
1322                    let mut body = req.into_body();
1323                    let req_tx = req_tx.clone();
1324                    async move {
1325                        let resp = match uri.as_str() {
1326                            "/client/llm_tokens" => serde_json::to_string(&json!({
1327                                "token": "test"
1328                            }))
1329                            .unwrap(),
1330                            "/predict_edits/v3" => {
1331                                let mut buf = Vec::new();
1332                                body.read_to_end(&mut buf).await.ok();
1333                                let req = serde_json::from_slice(&buf).unwrap();
1334
1335                                let (res_tx, res_rx) = oneshot::channel();
1336                                req_tx.unbounded_send((req, res_tx)).unwrap();
1337                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
1338                            }
1339                            _ => {
1340                                panic!("Unexpected path: {}", uri)
1341                            }
1342                        };
1343
1344                        Ok(Response::builder().body(resp.into()).unwrap())
1345                    }
1346                }
1347            });
1348
1349            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
1350            client.cloud_client().set_credentials(1, "test".into());
1351
1352            language_model::init(client.clone(), cx);
1353
1354            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
1355            let zeta = Zeta::global(&client, &user_store, cx);
1356
1357            (zeta, req_rx)
1358        })
1359    }
1360}