zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use chrono::TimeDelta;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, PromptFormat, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use futures::channel::mpsc;
  15use gpui::http_client::Method;
  16use gpui::{
  17    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  18    http_client, prelude::*,
  19};
  20use language::{Buffer, DiagnosticSet, LanguageServerId, ToOffset as _, ToPoint};
  21use language::{BufferSnapshot, TextBufferSnapshot};
  22use language_model::{LlmApiToken, RefreshLlmTokenListener};
  23use project::Project;
  24use release_channel::AppVersion;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::Path;
  27use std::str::FromStr as _;
  28use std::sync::Arc;
  29use std::time::{Duration, Instant};
  30use thiserror::Error;
  31use util::rel_path::RelPathBuf;
  32use util::some_or_debug_panic;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35mod prediction;
  36mod provider;
  37
  38use crate::prediction::EditPrediction;
  39pub use provider::ZetaEditPredictionProvider;
  40
/// Edits to the same buffer that arrive within this interval of the previous recorded
/// change (and continue from its snapshot) are merged into that change event.
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
///
/// When the history reaches this size, the oldest half is dropped in one go (rather than
/// popping one event at a time).
const MAX_EVENT_COUNT: usize = 16;

/// Default options for selecting the excerpt around the cursor.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};

/// Default [`ZetaOptions`], installed by [`Zeta::new`].
pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
    excerpt: DEFAULT_EXCERPT_OPTIONS,
    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
    max_diagnostic_bytes: 2048,
    prompt_format: PromptFormat::DEFAULT,
    file_indexing_parallelism: 1,
};
  59
/// GPUI global holding the app-wide [`Zeta`] entity created by [`Zeta::global`].
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  64
/// Edit prediction engine: tracks per-project edit history and buffers, and requests
/// predictions from the cloud endpoint.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token used to authenticate prediction requests; refreshed whenever the global
    /// `RefreshLlmTokenListener` emits an event.
    llm_token: LlmApiToken,
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    options: ZetaOptions,
    /// Set once a request fails with `ZedUpdateRequiredError`.
    update_required: bool,
    /// Installed by `debug_info`; receives debug info (or an error message) for each request
    /// sent to the server.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  75
/// Tunable options for building prediction requests and indexing projects.
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
    /// Options for selecting the excerpt around the cursor.
    pub excerpt: EditPredictionExcerptOptions,
    /// Prompt byte budget passed to the cloud request builder.
    pub max_prompt_bytes: usize,
    /// Budget, in bytes of serialized JSON, for the diagnostic groups attached to a request.
    pub max_diagnostic_bytes: usize,
    /// Prompt format passed to the cloud request builder.
    pub prompt_format: predict_edits_v3::PromptFormat,
    /// Parallelism passed to `SyntaxIndex::new` when a project's index is created.
    pub file_indexing_parallelism: usize,
}
  84
/// Debug information reported on the `Zeta::debug_info` channel for one prediction request.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for the request.
    pub context: EditPredictionContext,
    /// Time spent gathering context and building the request, measured before it was sent.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug information attached to a prediction response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  94
/// Prediction state tracked for a single project.
struct ZetaProject {
    /// The project's syntax index, used to gather context for requests.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer-change events, oldest first; never longer than `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers observed for edits, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
    /// The prediction currently offered for this project, if any.
    current_prediction: Option<CurrentEditPrediction>,
}
 101
/// A project's current prediction together with the buffer it was requested for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer passed to `refresh_prediction` for this prediction.
    pub requested_by_buffer_id: EntityId,
    pub prediction: EditPrediction,
}
 107
 108impl CurrentEditPrediction {
 109    fn should_replace_prediction(
 110        &self,
 111        old_prediction: &Self,
 112        snapshot: &TextBufferSnapshot,
 113    ) -> bool {
 114        if self.requested_by_buffer_id != old_prediction.requested_by_buffer_id {
 115            return true;
 116        }
 117
 118        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 119            return true;
 120        };
 121
 122        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 123            return false;
 124        };
 125        if old_edits.len() == 1 && new_edits.len() == 1 {
 126            let (old_range, old_text) = &old_edits[0];
 127            let (new_range, new_text) = &new_edits[0];
 128            new_range == old_range && new_text.starts_with(old_text)
 129        } else {
 130            true
 131        }
 132    }
 133}
 134
/// A prediction from the perspective of a buffer.
#[derive(Debug)]
enum BufferEditPrediction<'a> {
    /// The prediction targets this buffer.
    Local { prediction: &'a EditPrediction },
    /// The prediction does not target this buffer, but was requested for it.
    Jump { prediction: &'a EditPrediction },
}
 141
/// A buffer observed for edits by [`Zeta`].
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change; the next change is diffed against it.
    snapshot: BufferSnapshot,
    /// Held for the lifetime of the registration: the buffer-edit subscription and the
    /// release observer that unregisters the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
 146
/// An entry in a project's edit history.
#[derive(Clone)]
pub enum Event {
    /// A buffer changed from `old_snapshot` to `new_snapshot`. Quick successive changes to
    /// the same buffer may be coalesced into a single event, whose `timestamp` is that of the
    /// latest change.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        timestamp: Instant,
    },
}
 155
 156impl Zeta {
 157    pub fn try_global(cx: &App) -> Option<Entity<Self>> {
 158        cx.try_global::<ZetaGlobal>().map(|global| global.0.clone())
 159    }
 160
 161    pub fn global(
 162        client: &Arc<Client>,
 163        user_store: &Entity<UserStore>,
 164        cx: &mut App,
 165    ) -> Entity<Self> {
 166        cx.try_global::<ZetaGlobal>()
 167            .map(|global| global.0.clone())
 168            .unwrap_or_else(|| {
 169                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 170                cx.set_global(ZetaGlobal(zeta.clone()));
 171                zeta
 172            })
 173    }
 174
 175    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 176        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 177
 178        Self {
 179            projects: HashMap::new(),
 180            client,
 181            user_store,
 182            options: DEFAULT_OPTIONS,
 183            llm_token: LlmApiToken::default(),
 184            _llm_token_subscription: cx.subscribe(
 185                &refresh_llm_token_listener,
 186                |this, _listener, _event, cx| {
 187                    let client = this.client.clone();
 188                    let llm_token = this.llm_token.clone();
 189                    cx.spawn(async move |_this, _cx| {
 190                        llm_token.refresh(&client).await?;
 191                        anyhow::Ok(())
 192                    })
 193                    .detach_and_log_err(cx);
 194                },
 195            ),
 196            update_required: false,
 197            debug_tx: None,
 198        }
 199    }
 200
 201    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 202        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 203        self.debug_tx = Some(debug_watch_tx);
 204        debug_watch_rx
 205    }
 206
 207    pub fn options(&self) -> &ZetaOptions {
 208        &self.options
 209    }
 210
 211    pub fn set_options(&mut self, options: ZetaOptions) {
 212        self.options = options;
 213    }
 214
 215    pub fn clear_history(&mut self) {
 216        for zeta_project in self.projects.values_mut() {
 217            zeta_project.events.clear();
 218        }
 219    }
 220
 221    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 222        self.user_store.read(cx).edit_prediction_usage()
 223    }
 224
 225    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 226        self.get_or_init_zeta_project(project, cx);
 227    }
 228
 229    pub fn register_buffer(
 230        &mut self,
 231        buffer: &Entity<Buffer>,
 232        project: &Entity<Project>,
 233        cx: &mut Context<Self>,
 234    ) {
 235        let zeta_project = self.get_or_init_zeta_project(project, cx);
 236        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 237    }
 238
 239    fn get_or_init_zeta_project(
 240        &mut self,
 241        project: &Entity<Project>,
 242        cx: &mut App,
 243    ) -> &mut ZetaProject {
 244        self.projects
 245            .entry(project.entity_id())
 246            .or_insert_with(|| ZetaProject {
 247                syntax_index: cx.new(|cx| {
 248                    SyntaxIndex::new(project, self.options.file_indexing_parallelism, cx)
 249                }),
 250                events: VecDeque::new(),
 251                registered_buffers: HashMap::new(),
 252                current_prediction: None,
 253            })
 254    }
 255
 256    fn register_buffer_impl<'a>(
 257        zeta_project: &'a mut ZetaProject,
 258        buffer: &Entity<Buffer>,
 259        project: &Entity<Project>,
 260        cx: &mut Context<Self>,
 261    ) -> &'a mut RegisteredBuffer {
 262        let buffer_id = buffer.entity_id();
 263        match zeta_project.registered_buffers.entry(buffer_id) {
 264            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 265            hash_map::Entry::Vacant(entry) => {
 266                let snapshot = buffer.read(cx).snapshot();
 267                let project_entity_id = project.entity_id();
 268                entry.insert(RegisteredBuffer {
 269                    snapshot,
 270                    _subscriptions: [
 271                        cx.subscribe(buffer, {
 272                            let project = project.downgrade();
 273                            move |this, buffer, event, cx| {
 274                                if let language::BufferEvent::Edited = event
 275                                    && let Some(project) = project.upgrade()
 276                                {
 277                                    this.report_changes_for_buffer(&buffer, &project, cx);
 278                                }
 279                            }
 280                        }),
 281                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 282                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 283                            else {
 284                                return;
 285                            };
 286                            zeta_project.registered_buffers.remove(&buffer_id);
 287                        }),
 288                    ],
 289                })
 290            }
 291        }
 292    }
 293
 294    fn report_changes_for_buffer(
 295        &mut self,
 296        buffer: &Entity<Buffer>,
 297        project: &Entity<Project>,
 298        cx: &mut Context<Self>,
 299    ) -> BufferSnapshot {
 300        let zeta_project = self.get_or_init_zeta_project(project, cx);
 301        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 302
 303        let new_snapshot = buffer.read(cx).snapshot();
 304        if new_snapshot.version != registered_buffer.snapshot.version {
 305            let old_snapshot =
 306                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 307            Self::push_event(
 308                zeta_project,
 309                Event::BufferChange {
 310                    old_snapshot,
 311                    new_snapshot: new_snapshot.clone(),
 312                    timestamp: Instant::now(),
 313                },
 314            );
 315        }
 316
 317        new_snapshot
 318    }
 319
 320    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 321        let events = &mut zeta_project.events;
 322
 323        if let Some(Event::BufferChange {
 324            new_snapshot: last_new_snapshot,
 325            timestamp: last_timestamp,
 326            ..
 327        }) = events.back_mut()
 328        {
 329            // Coalesce edits for the same buffer when they happen one after the other.
 330            let Event::BufferChange {
 331                old_snapshot,
 332                new_snapshot,
 333                timestamp,
 334            } = &event;
 335
 336            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 337                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 338                && old_snapshot.version == last_new_snapshot.version
 339            {
 340                *last_new_snapshot = new_snapshot.clone();
 341                *last_timestamp = *timestamp;
 342                return;
 343            }
 344        }
 345
 346        if events.len() >= MAX_EVENT_COUNT {
 347            // These are halved instead of popping to improve prompt caching.
 348            events.drain(..MAX_EVENT_COUNT / 2);
 349        }
 350
 351        events.push_back(event);
 352    }
 353
 354    fn current_prediction_for_buffer(
 355        &self,
 356        buffer: &Entity<Buffer>,
 357        project: &Entity<Project>,
 358        cx: &App,
 359    ) -> Option<BufferEditPrediction<'_>> {
 360        let project_state = self.projects.get(&project.entity_id())?;
 361
 362        let CurrentEditPrediction {
 363            requested_by_buffer_id,
 364            prediction,
 365        } = project_state.current_prediction.as_ref()?;
 366
 367        if prediction.targets_buffer(buffer.read(cx), cx) {
 368            Some(BufferEditPrediction::Local { prediction })
 369        } else if *requested_by_buffer_id == buffer.entity_id() {
 370            Some(BufferEditPrediction::Jump { prediction })
 371        } else {
 372            None
 373        }
 374    }
 375
 376    fn accept_current_prediction(&mut self, project: &Entity<Project>) {
 377        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 378            project_state.current_prediction.take();
 379        };
 380        // TODO report accepted
 381    }
 382
 383    fn discard_current_prediction(&mut self, project: &Entity<Project>) {
 384        if let Some(project_state) = self.projects.get_mut(&project.entity_id()) {
 385            project_state.current_prediction.take();
 386        };
 387    }
 388
 389    pub fn refresh_prediction(
 390        &mut self,
 391        project: &Entity<Project>,
 392        buffer: &Entity<Buffer>,
 393        position: language::Anchor,
 394        cx: &mut Context<Self>,
 395    ) -> Task<Result<()>> {
 396        let request_task = self.request_prediction(project, buffer, position, cx);
 397        let buffer = buffer.clone();
 398        let project = project.clone();
 399
 400        cx.spawn(async move |this, cx| {
 401            if let Some(prediction) = request_task.await? {
 402                this.update(cx, |this, cx| {
 403                    let project_state = this
 404                        .projects
 405                        .get_mut(&project.entity_id())
 406                        .context("Project not found")?;
 407
 408                    let new_prediction = CurrentEditPrediction {
 409                        requested_by_buffer_id: buffer.entity_id(),
 410                        prediction: prediction,
 411                    };
 412
 413                    if project_state
 414                        .current_prediction
 415                        .as_ref()
 416                        .is_none_or(|old_prediction| {
 417                            new_prediction
 418                                .should_replace_prediction(&old_prediction, buffer.read(cx))
 419                        })
 420                    {
 421                        project_state.current_prediction = Some(new_prediction);
 422                    }
 423                    anyhow::Ok(())
 424                })??;
 425            }
 426            Ok(())
 427        })
 428    }
 429
 430    fn request_prediction(
 431        &mut self,
 432        project: &Entity<Project>,
 433        buffer: &Entity<Buffer>,
 434        position: language::Anchor,
 435        cx: &mut Context<Self>,
 436    ) -> Task<Result<Option<EditPrediction>>> {
 437        let project_state = self.projects.get(&project.entity_id());
 438
 439        let index_state = project_state.map(|state| {
 440            state
 441                .syntax_index
 442                .read_with(cx, |index, _cx| index.state().clone())
 443        });
 444        let options = self.options.clone();
 445        let snapshot = buffer.read(cx).snapshot();
 446        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx).into()) else {
 447            return Task::ready(Err(anyhow!("No file path for excerpt")));
 448        };
 449        let client = self.client.clone();
 450        let llm_token = self.llm_token.clone();
 451        let app_version = AppVersion::global(cx);
 452        let worktree_snapshots = project
 453            .read(cx)
 454            .worktrees(cx)
 455            .map(|worktree| worktree.read(cx).snapshot())
 456            .collect::<Vec<_>>();
 457        let debug_tx = self.debug_tx.clone();
 458
 459        let events = project_state
 460            .map(|state| {
 461                state
 462                    .events
 463                    .iter()
 464                    .filter_map(|event| match event {
 465                        Event::BufferChange {
 466                            old_snapshot,
 467                            new_snapshot,
 468                            ..
 469                        } => {
 470                            let path = new_snapshot.file().map(|f| f.full_path(cx));
 471
 472                            let old_path = old_snapshot.file().and_then(|f| {
 473                                let old_path = f.full_path(cx);
 474                                if Some(&old_path) != path.as_ref() {
 475                                    Some(old_path)
 476                                } else {
 477                                    None
 478                                }
 479                            });
 480
 481                            // TODO [zeta2] move to bg?
 482                            let diff =
 483                                language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 484
 485                            if path == old_path && diff.is_empty() {
 486                                None
 487                            } else {
 488                                Some(predict_edits_v3::Event::BufferChange {
 489                                    old_path,
 490                                    path,
 491                                    diff,
 492                                    //todo: Actually detect if this edit was predicted or not
 493                                    predicted: false,
 494                                })
 495                            }
 496                        }
 497                    })
 498                    .collect::<Vec<_>>()
 499            })
 500            .unwrap_or_default();
 501
 502        let diagnostics = snapshot.diagnostic_sets().clone();
 503
 504        let request_task = cx.background_spawn({
 505            let snapshot = snapshot.clone();
 506            let buffer = buffer.clone();
 507            async move {
 508                let index_state = if let Some(index_state) = index_state {
 509                    Some(index_state.lock_owned().await)
 510                } else {
 511                    None
 512                };
 513
 514                let cursor_offset = position.to_offset(&snapshot);
 515                let cursor_point = cursor_offset.to_point(&snapshot);
 516
 517                let before_retrieval = chrono::Utc::now();
 518
 519                let Some(context) = EditPredictionContext::gather_context(
 520                    cursor_point,
 521                    &snapshot,
 522                    &options.excerpt,
 523                    index_state.as_deref(),
 524                ) else {
 525                    return Ok(None);
 526                };
 527
 528                let debug_context = if let Some(debug_tx) = debug_tx {
 529                    Some((debug_tx, context.clone()))
 530                } else {
 531                    None
 532                };
 533
 534                let (diagnostic_groups, diagnostic_groups_truncated) =
 535                    Self::gather_nearby_diagnostics(
 536                        cursor_offset,
 537                        &diagnostics,
 538                        &snapshot,
 539                        options.max_diagnostic_bytes,
 540                    );
 541
 542                let request = make_cloud_request(
 543                    excerpt_path,
 544                    context,
 545                    events,
 546                    // TODO data collection
 547                    false,
 548                    diagnostic_groups,
 549                    diagnostic_groups_truncated,
 550                    None,
 551                    debug_context.is_some(),
 552                    &worktree_snapshots,
 553                    index_state.as_deref(),
 554                    Some(options.max_prompt_bytes),
 555                    options.prompt_format,
 556                );
 557
 558                let retrieval_time = chrono::Utc::now() - before_retrieval;
 559                let response = Self::perform_request(client, llm_token, app_version, request).await;
 560
 561                if let Some((debug_tx, context)) = debug_context {
 562                    debug_tx
 563                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 564                            |response| {
 565                                let Some(request) =
 566                                    some_or_debug_panic(response.0.debug_info.clone())
 567                                else {
 568                                    return Err("Missing debug info".to_string());
 569                                };
 570                                Ok(PredictionDebugInfo {
 571                                    context,
 572                                    request,
 573                                    retrieval_time,
 574                                    buffer: buffer.downgrade(),
 575                                    position,
 576                                })
 577                            },
 578                        ))
 579                        .ok();
 580                }
 581
 582                anyhow::Ok(Some(response?))
 583            }
 584        });
 585
 586        let buffer = buffer.clone();
 587
 588        cx.spawn({
 589            let project = project.clone();
 590            async move |this, cx| {
 591                match request_task.await {
 592                    Ok(Some((response, usage))) => {
 593                        if let Some(usage) = usage {
 594                            this.update(cx, |this, cx| {
 595                                this.user_store.update(cx, |user_store, cx| {
 596                                    user_store.update_edit_prediction_usage(usage, cx);
 597                                });
 598                            })
 599                            .ok();
 600                        }
 601
 602                        let prediction = EditPrediction::from_response(
 603                            response, &snapshot, &buffer, &project, cx,
 604                        )
 605                        .await;
 606
 607                        // TODO telemetry: duration, etc
 608                        Ok(prediction)
 609                    }
 610                    Ok(None) => Ok(None),
 611                    Err(err) => {
 612                        if err.is::<ZedUpdateRequiredError>() {
 613                            cx.update(|cx| {
 614                                this.update(cx, |this, _cx| {
 615                                    this.update_required = true;
 616                                })
 617                                .ok();
 618
 619                                let error_message: SharedString = err.to_string().into();
 620                                show_app_notification(
 621                                    NotificationId::unique::<ZedUpdateRequiredError>(),
 622                                    cx,
 623                                    move |cx| {
 624                                        cx.new(|cx| {
 625                                            ErrorMessagePrompt::new(error_message.clone(), cx)
 626                                                .with_link_button(
 627                                                    "Update Zed",
 628                                                    "https://zed.dev/releases",
 629                                                )
 630                                        })
 631                                    },
 632                                );
 633                            })
 634                            .ok();
 635                        }
 636
 637                        Err(err)
 638                    }
 639                }
 640            }
 641        })
 642    }
 643
 644    async fn perform_request(
 645        client: Arc<Client>,
 646        llm_token: LlmApiToken,
 647        app_version: SemanticVersion,
 648        request: predict_edits_v3::PredictEditsRequest,
 649    ) -> Result<(
 650        predict_edits_v3::PredictEditsResponse,
 651        Option<EditPredictionUsage>,
 652    )> {
 653        let http_client = client.http_client();
 654        let mut token = llm_token.acquire(&client).await?;
 655        let mut did_retry = false;
 656
 657        loop {
 658            let request_builder = http_client::Request::builder().method(Method::POST);
 659            let request_builder =
 660                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 661                    request_builder.uri(predict_edits_url)
 662                } else {
 663                    request_builder.uri(
 664                        http_client
 665                            .build_zed_llm_url("/predict_edits/v3", &[])?
 666                            .as_ref(),
 667                    )
 668                };
 669            let request = request_builder
 670                .header("Content-Type", "application/json")
 671                .header("Authorization", format!("Bearer {}", token))
 672                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 673                .body(serde_json::to_string(&request)?.into())?;
 674
 675            let mut response = http_client.send(request).await?;
 676
 677            if let Some(minimum_required_version) = response
 678                .headers()
 679                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 680                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 681            {
 682                anyhow::ensure!(
 683                    app_version >= minimum_required_version,
 684                    ZedUpdateRequiredError {
 685                        minimum_version: minimum_required_version
 686                    }
 687                );
 688            }
 689
 690            if response.status().is_success() {
 691                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 692
 693                let mut body = Vec::new();
 694                response.body_mut().read_to_end(&mut body).await?;
 695                return Ok((serde_json::from_slice(&body)?, usage));
 696            } else if !did_retry
 697                && response
 698                    .headers()
 699                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 700                    .is_some()
 701            {
 702                did_retry = true;
 703                token = llm_token.refresh(&client).await?;
 704            } else {
 705                let mut body = String::new();
 706                response.body_mut().read_to_string(&mut body).await?;
 707                anyhow::bail!(
 708                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 709                    response.status(),
 710                    body
 711                );
 712            }
 713        }
 714    }
 715
 716    fn gather_nearby_diagnostics(
 717        cursor_offset: usize,
 718        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 719        snapshot: &BufferSnapshot,
 720        max_diagnostics_bytes: usize,
 721    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 722        // TODO: Could make this more efficient
 723        let mut diagnostic_groups = Vec::new();
 724        for (language_server_id, diagnostics) in diagnostic_sets {
 725            let mut groups = Vec::new();
 726            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 727            diagnostic_groups.extend(
 728                groups
 729                    .into_iter()
 730                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 731            );
 732        }
 733
 734        // sort by proximity to cursor
 735        diagnostic_groups.sort_by_key(|group| {
 736            let range = &group.entries[group.primary_ix].range;
 737            if range.start >= cursor_offset {
 738                range.start - cursor_offset
 739            } else if cursor_offset >= range.end {
 740                cursor_offset - range.end
 741            } else {
 742                (cursor_offset - range.start).min(range.end - cursor_offset)
 743            }
 744        });
 745
 746        let mut results = Vec::new();
 747        let mut diagnostic_groups_truncated = false;
 748        let mut diagnostics_byte_count = 0;
 749        for group in diagnostic_groups {
 750            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 751            diagnostics_byte_count += raw_value.get().len();
 752            if diagnostics_byte_count > max_diagnostics_bytes {
 753                diagnostic_groups_truncated = true;
 754                break;
 755            }
 756            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 757        }
 758
 759        (results, diagnostic_groups_truncated)
 760    }
 761
 762    // TODO: Dedupe with similar code in request_prediction?
 763    pub fn cloud_request_for_zeta_cli(
 764        &mut self,
 765        project: &Entity<Project>,
 766        buffer: &Entity<Buffer>,
 767        position: language::Anchor,
 768        cx: &mut Context<Self>,
 769    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 770        let project_state = self.projects.get(&project.entity_id());
 771
 772        let index_state = project_state.map(|state| {
 773            state
 774                .syntax_index
 775                .read_with(cx, |index, _cx| index.state().clone())
 776        });
 777        let options = self.options.clone();
 778        let snapshot = buffer.read(cx).snapshot();
 779        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 780            return Task::ready(Err(anyhow!("No file path for excerpt")));
 781        };
 782        let worktree_snapshots = project
 783            .read(cx)
 784            .worktrees(cx)
 785            .map(|worktree| worktree.read(cx).snapshot())
 786            .collect::<Vec<_>>();
 787
 788        cx.background_spawn(async move {
 789            let index_state = if let Some(index_state) = index_state {
 790                Some(index_state.lock_owned().await)
 791            } else {
 792                None
 793            };
 794
 795            let cursor_point = position.to_point(&snapshot);
 796
 797            let debug_info = true;
 798            EditPredictionContext::gather_context(
 799                cursor_point,
 800                &snapshot,
 801                &options.excerpt,
 802                index_state.as_deref(),
 803            )
 804            .context("Failed to select excerpt")
 805            .map(|context| {
 806                make_cloud_request(
 807                    excerpt_path.into(),
 808                    context,
 809                    // TODO pass everything
 810                    Vec::new(),
 811                    false,
 812                    Vec::new(),
 813                    false,
 814                    None,
 815                    debug_info,
 816                    &worktree_snapshots,
 817                    index_state.as_deref(),
 818                    Some(options.max_prompt_bytes),
 819                    options.prompt_format,
 820                )
 821            })
 822        })
 823    }
 824
 825    pub fn wait_for_initial_indexing(
 826        &mut self,
 827        project: &Entity<Project>,
 828        cx: &mut App,
 829    ) -> Task<Result<()>> {
 830        let zeta_project = self.get_or_init_zeta_project(project, cx);
 831        zeta_project
 832            .syntax_index
 833            .read(cx)
 834            .wait_for_initial_file_indexing(cx)
 835    }
 836}
 837
 838#[derive(Error, Debug)]
 839#[error(
 840    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 841)]
 842pub struct ZedUpdateRequiredError {
 843    minimum_version: SemanticVersion,
 844}
 845
 846fn make_cloud_request(
 847    excerpt_path: Arc<Path>,
 848    context: EditPredictionContext,
 849    events: Vec<predict_edits_v3::Event>,
 850    can_collect_data: bool,
 851    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 852    diagnostic_groups_truncated: bool,
 853    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 854    debug_info: bool,
 855    worktrees: &Vec<worktree::Snapshot>,
 856    index_state: Option<&SyntaxIndexState>,
 857    prompt_max_bytes: Option<usize>,
 858    prompt_format: PromptFormat,
 859) -> predict_edits_v3::PredictEditsRequest {
 860    let mut signatures = Vec::new();
 861    let mut declaration_to_signature_index = HashMap::default();
 862    let mut referenced_declarations = Vec::new();
 863
 864    for snippet in context.declarations {
 865        let project_entry_id = snippet.declaration.project_entry_id();
 866        let Some(path) = worktrees.iter().find_map(|worktree| {
 867            worktree.entry_for_id(project_entry_id).map(|entry| {
 868                let mut full_path = RelPathBuf::new();
 869                full_path.push(worktree.root_name());
 870                full_path.push(&entry.path);
 871                full_path
 872            })
 873        }) else {
 874            continue;
 875        };
 876
 877        let parent_index = index_state.and_then(|index_state| {
 878            snippet.declaration.parent().and_then(|parent| {
 879                add_signature(
 880                    parent,
 881                    &mut declaration_to_signature_index,
 882                    &mut signatures,
 883                    index_state,
 884                )
 885            })
 886        });
 887
 888        let (text, text_is_truncated) = snippet.declaration.item_text();
 889        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 890            path: path.as_std_path().into(),
 891            text: text.into(),
 892            range: snippet.declaration.item_range(),
 893            text_is_truncated,
 894            signature_range: snippet.declaration.signature_range_in_item_text(),
 895            parent_index,
 896            score_components: snippet.score_components,
 897            signature_score: snippet.scores.signature,
 898            declaration_score: snippet.scores.declaration,
 899        });
 900    }
 901
 902    let excerpt_parent = index_state.and_then(|index_state| {
 903        context
 904            .excerpt
 905            .parent_declarations
 906            .last()
 907            .and_then(|(parent, _)| {
 908                add_signature(
 909                    *parent,
 910                    &mut declaration_to_signature_index,
 911                    &mut signatures,
 912                    index_state,
 913                )
 914            })
 915    });
 916
 917    predict_edits_v3::PredictEditsRequest {
 918        excerpt_path,
 919        excerpt: context.excerpt_text.body,
 920        excerpt_range: context.excerpt.range,
 921        cursor_offset: context.cursor_offset_in_excerpt,
 922        referenced_declarations,
 923        signatures,
 924        excerpt_parent,
 925        events,
 926        can_collect_data,
 927        diagnostic_groups,
 928        diagnostic_groups_truncated,
 929        git_info,
 930        debug_info,
 931        prompt_max_bytes,
 932        prompt_format,
 933    }
 934}
 935
 936fn add_signature(
 937    declaration_id: DeclarationId,
 938    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 939    signatures: &mut Vec<Signature>,
 940    index: &SyntaxIndexState,
 941) -> Option<usize> {
 942    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 943        return Some(*signature_index);
 944    }
 945    let Some(parent_declaration) = index.declaration(declaration_id) else {
 946        log::error!("bug: missing parent declaration");
 947        return None;
 948    };
 949    let parent_index = parent_declaration.parent().and_then(|parent| {
 950        add_signature(parent, declaration_to_signature_index, signatures, index)
 951    });
 952    let (text, text_is_truncated) = parent_declaration.signature_text();
 953    let signature_index = signatures.len();
 954    signatures.push(Signature {
 955        text: text.into(),
 956        text_is_truncated,
 957        parent_index,
 958        range: parent_declaration.signature_range(),
 959    });
 960    declaration_to_signature_index.insert(declaration_id, signature_index);
 961    Some(signature_index)
 962}
 963
#[cfg(test)]
mod tests {
    // Tests for `Zeta` request building and response handling. All network traffic goes
    // through the fake HTTP client created in `init_test`. That client hands each
    // `/predict_edits/v3` request to the test and waits for the test to send a response.
    use std::{
        path::{Path, PathBuf},
        sync::Arc,
    };

    use client::UserStore;
    use clock::FakeSystemClock;
    use cloud_llm_client::predict_edits_v3;
    use futures::{
        AsyncReadExt, StreamExt,
        channel::{mpsc, oneshot},
    };
    use gpui::{
        Entity, TestAppContext,
        http_client::{FakeHttpClient, Response},
        prelude::*,
    };
    use indoc::indoc;
    use language::{LanguageServerId, OffsetRangeExt as _};
    use pretty_assertions::{assert_eq, assert_matches};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use util::path;
    use uuid::Uuid;

    use crate::{BufferEditPrediction, Zeta};

    /// Checks how `current_prediction_for_buffer` classifies the current prediction.
    ///
    /// An edit to the requesting buffer is `Local`. An edit to another file is a `Jump` when
    /// viewed from the requesting buffer, and `Local` once viewed from that other file's buffer.
    #[gpui::test]
    async fn test_current_state(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "1.txt": "Hello!\nHow\nBye",
                "2.txt": "Hola!\nComo\nAdios"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        zeta.update(cx, |zeta, cx| {
            zeta.register_project(&project, cx);
        });

        let buffer1 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/1.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot1.anchor_before(language::Point::new(1, 3));

        // Prediction for current file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        // Take the request captured by the fake HTTP client and answer it with an edit
        // to 1.txt.
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/1.txt")).into(),
                    range: 0..snapshot1.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });

        // Prediction for another file

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.refresh_prediction(&project, &buffer1, position, cx)
        });
        let (_request, respond_tx) = req_rx.next().await.unwrap();
        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/2.txt")).into(),
                    // NOTE(review): this range uses 1.txt's length (14 bytes), while 2.txt is
                    // 16 bytes long. The assertions below only check the prediction kind, so the
                    // test passes either way. Confirm whether the full 2.txt range was intended.
                    range: 0..snapshot1.len(),
                    content: "Hola!\nComo estas?\nAdios".into(),
                }],
                debug_info: None,
            })
            .unwrap();
        prediction_task.await.unwrap();

        // Seen from buffer1, an edit to a different file is a jump to that file.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer1, &project, cx)
                .unwrap();
            assert_matches!(
                prediction,
                BufferEditPrediction::Jump { prediction } if prediction.path.as_ref() == Path::new(path!("root/2.txt"))
            );
        });

        let buffer2 = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/2.txt"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        // Once 2.txt is open, the same prediction is local to its buffer.
        zeta.read_with(cx, |zeta, cx| {
            let prediction = zeta
                .current_prediction_for_buffer(&buffer2, &project, cx)
                .unwrap();
            assert_matches!(prediction, BufferEditPrediction::Local { .. });
        });
    }

    /// Sends one prediction request and checks the request's path and cursor offset.
    ///
    /// Also checks that a whole-file replacement in the response is reduced to a single edit
    /// inserting " are you?" at the cursor.
    #[gpui::test]
    async fn test_simple_request(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md":  "Hello!\nHow\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();
        assert_eq!(
            request.excerpt_path.as_ref(),
            Path::new(path!("root/foo.md"))
        );
        // Point (1, 3) is byte 10 of the file ("Hello!\n" is 7 bytes). This presumes the
        // excerpt of this tiny file starts at offset 0.
        assert_eq!(request.cursor_offset, 10);

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that an edit made to a buffer after `register_buffer` is sent in the request's
    /// `events` as a `BufferChange` carrying a unified diff.
    #[gpui::test]
    async fn test_request_events(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\n\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        zeta.update(cx, |zeta, cx| {
            zeta.register_buffer(&buffer, &project, cx);
        });

        // Fill the empty second line: offset 7 is just after "Hello!\n".
        buffer.update(cx, |buffer, cx| {
            buffer.edit(vec![(7..7, "How")], None, cx);
        });

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(1, 3));

        let prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        let (request, respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.events.len(), 1);
        assert_eq!(
            request.events[0],
            predict_edits_v3::Event::BufferChange {
                path: Some(PathBuf::from(path!("root/foo.md"))),
                old_path: None,
                diff: indoc! {"
                        @@ -1,3 +1,3 @@
                         Hello!
                        -
                        +How
                         Bye
                    "}
                .to_string(),
                predicted: false
            }
        );

        respond_tx
            .send(predict_edits_v3::PredictEditsResponse {
                request_id: Uuid::new_v4(),
                edits: vec![predict_edits_v3::Edit {
                    path: Path::new(path!("root/foo.md")).into(),
                    range: 0..snapshot.len(),
                    content: "Hello!\nHow are you?\nBye".into(),
                }],
                debug_info: None,
            })
            .unwrap();

        let prediction = prediction_task.await.unwrap().unwrap();

        assert_eq!(prediction.edits.len(), 1);
        assert_eq!(
            prediction.edits[0].0.to_point(&snapshot).start,
            language::Point::new(1, 3)
        );
        assert_eq!(prediction.edits[0].1, " are you?");
    }

    /// Checks that diagnostics published for the buffer's file are serialized into the
    /// request's `diagnostic_groups`.
    #[gpui::test]
    async fn test_request_diagnostics(cx: &mut TestAppContext) {
        let (zeta, mut req_rx) = init_test(cx);
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            "/root",
            json!({
                "foo.md": "Hello!\nBye"
            }),
        )
        .await;
        let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;

        let path_to_buffer_uri = lsp::Uri::from_file_path(path!("/root/foo.md")).unwrap();
        // The range covers line 1 ("Bye") from column 1 to column 5. The end column lies past
        // the end of the 3-character line.
        let diagnostic = lsp::Diagnostic {
            range: lsp::Range::new(lsp::Position::new(1, 1), lsp::Position::new(1, 5)),
            severity: Some(lsp::DiagnosticSeverity::ERROR),
            message: "\"Hello\" deprecated. Use \"Hi\" instead".to_string(),
            ..Default::default()
        };

        // Diagnostics are published before the buffer is opened.
        project.update(cx, |project, cx| {
            project.lsp_store().update(cx, |lsp_store, cx| {
                // Create some diagnostics
                lsp_store
                    .update_diagnostics(
                        LanguageServerId(0),
                        lsp::PublishDiagnosticsParams {
                            uri: path_to_buffer_uri.clone(),
                            diagnostics: vec![diagnostic],
                            version: None,
                        },
                        None,
                        language::DiagnosticSourceKind::Pushed,
                        &[],
                        cx,
                    )
                    .unwrap();
            });
        });

        let buffer = project
            .update(cx, |project, cx| {
                let path = project.find_project_path(path!("root/foo.md"), cx).unwrap();
                project.open_buffer(path, cx)
            })
            .await
            .unwrap();

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());
        let position = snapshot.anchor_before(language::Point::new(0, 0));

        let _prediction_task = zeta.update(cx, |zeta, cx| {
            zeta.request_prediction(&project, &buffer, position, cx)
        });

        // Only the request is inspected; no response is ever sent.
        let (request, _respond_tx) = req_rx.next().await.unwrap();

        assert_eq!(request.diagnostic_groups.len(), 1);
        let value = serde_json::from_str::<serde_json::Value>(request.diagnostic_groups[0].0.get())
            .unwrap();
        // Expected byte range 8..10: line 1 starts at byte 7, so column 1 is byte 8. Byte 10 is
        // the end of "Bye", which suggests the out-of-range end column is clipped to the line end.
        // We probably don't need all of this. TODO define a specific diagnostic type in predict_edits_v3
        assert_eq!(
            value,
            json!({
                "entries": [{
                    "range": {
                        "start": 8,
                        "end": 10
                    },
                    "diagnostic": {
                        "source": null,
                        "code": null,
                        "code_description": null,
                        "severity": 1,
                        "message": "\"Hello\" deprecated. Use \"Hi\" instead",
                        "markdown": null,
                        "group_id": 0,
                        "is_primary": true,
                        "is_disk_based": false,
                        "is_unnecessary": false,
                        "source_kind": "Pushed",
                        "data": null,
                        "underline": true
                    }
                }],
                "primary_ix": 0
            })
        );
    }

    /// Initializes test globals and a `Zeta` whose HTTP traffic goes to a fake client.
    ///
    /// Returns the `Zeta` entity and a receiver. The receiver yields each parsed
    /// `/predict_edits/v3` request together with a oneshot sender, which the test uses to
    /// supply that request's response.
    fn init_test(
        cx: &mut TestAppContext,
    ) -> (
        Entity<Zeta>,
        mpsc::UnboundedReceiver<(
            predict_edits_v3::PredictEditsRequest,
            oneshot::Sender<predict_edits_v3::PredictEditsResponse>,
        )>,
    ) {
        cx.update(move |cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);

            let (req_tx, req_rx) = mpsc::unbounded();

            // Routes requests by URI path:
            // - `/client/llm_tokens` returns a fixed token.
            // - `/predict_edits/v3` forwards the decoded request to the test and waits for
            //   the test's response.
            // - Any other path panics.
            let http_client = FakeHttpClient::create({
                move |req| {
                    let uri = req.uri().path().to_string();
                    let mut body = req.into_body();
                    let req_tx = req_tx.clone();
                    async move {
                        let resp = match uri.as_str() {
                            "/client/llm_tokens" => serde_json::to_string(&json!({
                                "token": "test"
                            }))
                            .unwrap(),
                            "/predict_edits/v3" => {
                                let mut buf = Vec::new();
                                body.read_to_end(&mut buf).await.ok();
                                let req = serde_json::from_slice(&buf).unwrap();

                                let (res_tx, res_rx) = oneshot::channel();
                                req_tx.unbounded_send((req, res_tx)).unwrap();
                                serde_json::to_string(&res_rx.await.unwrap()).unwrap()
                            }
                            _ => {
                                panic!("Unexpected path: {}", uri)
                            }
                        };

                        Ok(Response::builder().body(resp.into()).unwrap())
                    }
                }
            });

            let client = client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx);
            client.cloud_client().set_credentials(1, "test".into());

            language_model::init(client.clone(), cx);

            let user_store = cx.new(|cx| UserStore::new(client.clone(), cx));
            let zeta = Zeta::global(&client, &user_store, cx);

            (zeta, req_rx)
        })
    }
}