zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use gpui::http_client::Method;
  15use gpui::{
  16    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
  17    prelude::*,
  18};
  19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  20use language::{BufferSnapshot, EditPreview};
  21use language_model::{LlmApiToken, RefreshLlmTokenListener};
  22use project::Project;
  23use release_channel::AppVersion;
  24use std::cmp;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::PathBuf;
  27use std::str::FromStr as _;
  28use std::time::{Duration, Instant};
  29use std::{ops::Range, sync::Arc};
  30use thiserror::Error;
  31use util::ResultExt as _;
  32use uuid::Uuid;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  36
  37/// Maximum number of events to track.
  38const MAX_EVENT_COUNT: usize = 16;
  39
  40#[derive(Clone)]
  41struct ZetaGlobal(Entity<Zeta>);
  42
  43impl Global for ZetaGlobal {}
  44
  45pub struct Zeta {
  46    client: Arc<Client>,
  47    user_store: Entity<UserStore>,
  48    llm_token: LlmApiToken,
  49    _llm_token_subscription: Subscription,
  50    projects: HashMap<EntityId, ZetaProject>,
  51    pub excerpt_options: EditPredictionExcerptOptions,
  52    update_required: bool,
  53}
  54
  55struct ZetaProject {
  56    syntax_index: Entity<SyntaxIndex>,
  57    events: VecDeque<Event>,
  58    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  59}
  60
  61struct RegisteredBuffer {
  62    snapshot: BufferSnapshot,
  63    _subscriptions: [gpui::Subscription; 2],
  64}
  65
  66#[derive(Clone)]
  67pub enum Event {
  68    BufferChange {
  69        old_snapshot: BufferSnapshot,
  70        new_snapshot: BufferSnapshot,
  71        timestamp: Instant,
  72    },
  73}
  74
  75impl Zeta {
  76    pub fn global(
  77        client: &Arc<Client>,
  78        user_store: &Entity<UserStore>,
  79        cx: &mut App,
  80    ) -> Entity<Self> {
  81        cx.try_global::<ZetaGlobal>()
  82            .map(|global| global.0.clone())
  83            .unwrap_or_else(|| {
  84                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
  85                cx.set_global(ZetaGlobal(zeta.clone()));
  86                zeta
  87            })
  88    }
  89
  90    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
  91        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
  92
  93        Self {
  94            projects: HashMap::new(),
  95            client,
  96            user_store,
  97            excerpt_options: EditPredictionExcerptOptions {
  98                max_bytes: 512,
  99                min_bytes: 128,
 100                target_before_cursor_over_total_bytes: 0.5,
 101            },
 102            llm_token: LlmApiToken::default(),
 103            _llm_token_subscription: cx.subscribe(
 104                &refresh_llm_token_listener,
 105                |this, _listener, _event, cx| {
 106                    let client = this.client.clone();
 107                    let llm_token = this.llm_token.clone();
 108                    cx.spawn(async move |_this, _cx| {
 109                        llm_token.refresh(&client).await?;
 110                        anyhow::Ok(())
 111                    })
 112                    .detach_and_log_err(cx);
 113                },
 114            ),
 115            update_required: false,
 116        }
 117    }
 118
 119    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 120        self.user_store.read(cx).edit_prediction_usage()
 121    }
 122
 123    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 124        self.get_or_init_zeta_project(project, cx);
 125    }
 126
 127    pub fn register_buffer(
 128        &mut self,
 129        buffer: &Entity<Buffer>,
 130        project: &Entity<Project>,
 131        cx: &mut Context<Self>,
 132    ) {
 133        let zeta_project = self.get_or_init_zeta_project(project, cx);
 134        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 135    }
 136
 137    fn get_or_init_zeta_project(
 138        &mut self,
 139        project: &Entity<Project>,
 140        cx: &mut App,
 141    ) -> &mut ZetaProject {
 142        self.projects
 143            .entry(project.entity_id())
 144            .or_insert_with(|| ZetaProject {
 145                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 146                events: VecDeque::new(),
 147                registered_buffers: HashMap::new(),
 148            })
 149    }
 150
 151    fn register_buffer_impl<'a>(
 152        zeta_project: &'a mut ZetaProject,
 153        buffer: &Entity<Buffer>,
 154        project: &Entity<Project>,
 155        cx: &mut Context<Self>,
 156    ) -> &'a mut RegisteredBuffer {
 157        let buffer_id = buffer.entity_id();
 158        match zeta_project.registered_buffers.entry(buffer_id) {
 159            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 160            hash_map::Entry::Vacant(entry) => {
 161                let snapshot = buffer.read(cx).snapshot();
 162                let project_entity_id = project.entity_id();
 163                entry.insert(RegisteredBuffer {
 164                    snapshot,
 165                    _subscriptions: [
 166                        cx.subscribe(buffer, {
 167                            let project = project.downgrade();
 168                            move |this, buffer, event, cx| {
 169                                if let language::BufferEvent::Edited = event
 170                                    && let Some(project) = project.upgrade()
 171                                {
 172                                    this.report_changes_for_buffer(&buffer, &project, cx);
 173                                }
 174                            }
 175                        }),
 176                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 177                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 178                            else {
 179                                return;
 180                            };
 181                            zeta_project.registered_buffers.remove(&buffer_id);
 182                        }),
 183                    ],
 184                })
 185            }
 186        }
 187    }
 188
 189    fn report_changes_for_buffer(
 190        &mut self,
 191        buffer: &Entity<Buffer>,
 192        project: &Entity<Project>,
 193        cx: &mut Context<Self>,
 194    ) -> BufferSnapshot {
 195        let zeta_project = self.get_or_init_zeta_project(project, cx);
 196        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 197
 198        let new_snapshot = buffer.read(cx).snapshot();
 199        if new_snapshot.version != registered_buffer.snapshot.version {
 200            let old_snapshot =
 201                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 202            Self::push_event(
 203                zeta_project,
 204                Event::BufferChange {
 205                    old_snapshot,
 206                    new_snapshot: new_snapshot.clone(),
 207                    timestamp: Instant::now(),
 208                },
 209            );
 210        }
 211
 212        new_snapshot
 213    }
 214
 215    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 216        let events = &mut zeta_project.events;
 217
 218        if let Some(Event::BufferChange {
 219            new_snapshot: last_new_snapshot,
 220            timestamp: last_timestamp,
 221            ..
 222        }) = events.back_mut()
 223        {
 224            // Coalesce edits for the same buffer when they happen one after the other.
 225            let Event::BufferChange {
 226                old_snapshot,
 227                new_snapshot,
 228                timestamp,
 229            } = &event;
 230
 231            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 232                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 233                && old_snapshot.version == last_new_snapshot.version
 234            {
 235                *last_new_snapshot = new_snapshot.clone();
 236                *last_timestamp = *timestamp;
 237                return;
 238            }
 239        }
 240
 241        if events.len() >= MAX_EVENT_COUNT {
 242            // These are halved instead of popping to improve prompt caching.
 243            events.drain(..MAX_EVENT_COUNT / 2);
 244        }
 245
 246        events.push_back(event);
 247    }
 248
 249    pub fn request_prediction(
 250        &mut self,
 251        project: &Entity<Project>,
 252        buffer: &Entity<Buffer>,
 253        position: language::Anchor,
 254        cx: &mut Context<Self>,
 255    ) -> Task<Result<Option<EditPrediction>>> {
 256        let project_state = self.projects.get(&project.entity_id());
 257
 258        let index_state = project_state.map(|state| {
 259            state
 260                .syntax_index
 261                .read_with(cx, |index, _cx| index.state().clone())
 262        });
 263        let excerpt_options = self.excerpt_options.clone();
 264        let snapshot = buffer.read(cx).snapshot();
 265        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 266            return Task::ready(Err(anyhow!("No file path for excerpt")));
 267        };
 268        let client = self.client.clone();
 269        let llm_token = self.llm_token.clone();
 270        let app_version = AppVersion::global(cx);
 271        let worktree_snapshots = project
 272            .read(cx)
 273            .worktrees(cx)
 274            .map(|worktree| worktree.read(cx).snapshot())
 275            .collect::<Vec<_>>();
 276
 277        let request_task = cx.background_spawn({
 278            let snapshot = snapshot.clone();
 279            async move {
 280                let index_state = if let Some(index_state) = index_state {
 281                    Some(index_state.lock_owned().await)
 282                } else {
 283                    None
 284                };
 285
 286                let cursor_point = position.to_point(&snapshot);
 287
 288                // TODO: make this only true if debug view is open
 289                let debug_info = true;
 290
 291                let Some(request) = EditPredictionContext::gather_context(
 292                    cursor_point,
 293                    &snapshot,
 294                    &excerpt_options,
 295                    index_state.as_deref(),
 296                )
 297                .map(|context| {
 298                    make_cloud_request(
 299                        excerpt_path.clone(),
 300                        context,
 301                        // TODO pass everything
 302                        Vec::new(),
 303                        false,
 304                        Vec::new(),
 305                        None,
 306                        debug_info,
 307                        &worktree_snapshots,
 308                        index_state.as_deref(),
 309                    )
 310                }) else {
 311                    return Ok(None);
 312                };
 313
 314                anyhow::Ok(Some(
 315                    Self::perform_request(client, llm_token, app_version, request).await?,
 316                ))
 317            }
 318        });
 319
 320        let buffer = buffer.clone();
 321
 322        cx.spawn(async move |this, cx| {
 323            match request_task.await {
 324                Ok(Some((response, usage))) => {
 325                    log::debug!("predicted edits: {:?}", &response.edits);
 326
 327                    if let Some(usage) = usage {
 328                        this.update(cx, |this, cx| {
 329                            this.user_store.update(cx, |user_store, cx| {
 330                                user_store.update_edit_prediction_usage(usage, cx);
 331                            });
 332                        })
 333                        .ok();
 334                    }
 335
 336                    // TODO telemetry: duration, etc
 337
 338                    // TODO produce smaller edits by diffing against snapshot first
 339                    //
 340                    // Cloud returns entire snippets/excerpts ranges as they were included
 341                    // in the request, but we should display smaller edits to the user.
 342                    //
 343                    // We can do this by computing a diff of each one against the snapshot.
 344                    // Similar to zeta::Zeta::compute_edits, but per edit.
 345                    let edits = response
 346                        .edits
 347                        .into_iter()
 348                        .map(|edit| {
 349                            // TODO edits to different files
 350                            (
 351                                snapshot.anchor_before(edit.range.start)
 352                                    ..snapshot.anchor_before(edit.range.end),
 353                                edit.content,
 354                            )
 355                        })
 356                        .collect::<Vec<_>>()
 357                        .into();
 358
 359                    let Some((edits, snapshot, edit_preview_task)) =
 360                        buffer.read_with(cx, |buffer, cx| {
 361                            let new_snapshot = buffer.snapshot();
 362                            let edits: Arc<[_]> =
 363                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 364                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 365                        })?
 366                    else {
 367                        return Ok(None);
 368                    };
 369
 370                    Ok(Some(EditPrediction {
 371                        id: EditPredictionId(response.request_id),
 372                        edits,
 373                        snapshot,
 374                        edit_preview: edit_preview_task.await,
 375                    }))
 376                }
 377                Ok(None) => Ok(None),
 378                Err(err) => {
 379                    if err.is::<ZedUpdateRequiredError>() {
 380                        cx.update(|cx| {
 381                            this.update(cx, |this, _cx| {
 382                                this.update_required = true;
 383                            })
 384                            .ok();
 385
 386                            let error_message: SharedString = err.to_string().into();
 387                            show_app_notification(
 388                                NotificationId::unique::<ZedUpdateRequiredError>(),
 389                                cx,
 390                                move |cx| {
 391                                    cx.new(|cx| {
 392                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 393                                            .with_link_button(
 394                                                "Update Zed",
 395                                                "https://zed.dev/releases",
 396                                            )
 397                                    })
 398                                },
 399                            );
 400                        })
 401                        .ok();
 402                    }
 403
 404                    Err(err)
 405                }
 406            }
 407        })
 408    }
 409
 410    async fn perform_request(
 411        client: Arc<Client>,
 412        llm_token: LlmApiToken,
 413        app_version: SemanticVersion,
 414        request: predict_edits_v3::PredictEditsRequest,
 415    ) -> Result<(
 416        predict_edits_v3::PredictEditsResponse,
 417        Option<EditPredictionUsage>,
 418    )> {
 419        let http_client = client.http_client();
 420        let mut token = llm_token.acquire(&client).await?;
 421        let mut did_retry = false;
 422
 423        loop {
 424            let request_builder = http_client::Request::builder().method(Method::POST);
 425            let request_builder =
 426                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 427                    request_builder.uri(predict_edits_url)
 428                } else {
 429                    request_builder.uri(
 430                        http_client
 431                            .build_zed_llm_url("/predict_edits/v3", &[])?
 432                            .as_ref(),
 433                    )
 434                };
 435            let request = request_builder
 436                .header("Content-Type", "application/json")
 437                .header("Authorization", format!("Bearer {}", token))
 438                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 439                .body(serde_json::to_string(&request)?.into())?;
 440
 441            let mut response = http_client.send(request).await?;
 442
 443            if let Some(minimum_required_version) = response
 444                .headers()
 445                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 446                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 447            {
 448                anyhow::ensure!(
 449                    app_version >= minimum_required_version,
 450                    ZedUpdateRequiredError {
 451                        minimum_version: minimum_required_version
 452                    }
 453                );
 454            }
 455
 456            if response.status().is_success() {
 457                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 458
 459                let mut body = Vec::new();
 460                response.body_mut().read_to_end(&mut body).await?;
 461                return Ok((serde_json::from_slice(&body)?, usage));
 462            } else if !did_retry
 463                && response
 464                    .headers()
 465                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 466                    .is_some()
 467            {
 468                did_retry = true;
 469                token = llm_token.refresh(&client).await?;
 470            } else {
 471                let mut body = String::new();
 472                response.body_mut().read_to_string(&mut body).await?;
 473                anyhow::bail!(
 474                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 475                    response.status(),
 476                    body
 477                );
 478            }
 479        }
 480    }
 481
 482    // TODO: Dedupe with similar code in request_prediction?
 483    pub fn cloud_request_for_zeta_cli(
 484        &mut self,
 485        project: &Entity<Project>,
 486        buffer: &Entity<Buffer>,
 487        position: language::Anchor,
 488        cx: &mut Context<Self>,
 489    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 490        let project_state = self.projects.get(&project.entity_id());
 491
 492        let index_state = project_state.map(|state| {
 493            state
 494                .syntax_index
 495                .read_with(cx, |index, _cx| index.state().clone())
 496        });
 497        let excerpt_options = self.excerpt_options.clone();
 498        let snapshot = buffer.read(cx).snapshot();
 499        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 500            return Task::ready(Err(anyhow!("No file path for excerpt")));
 501        };
 502        let worktree_snapshots = project
 503            .read(cx)
 504            .worktrees(cx)
 505            .map(|worktree| worktree.read(cx).snapshot())
 506            .collect::<Vec<_>>();
 507
 508        cx.background_spawn(async move {
 509            let index_state = if let Some(index_state) = index_state {
 510                Some(index_state.lock_owned().await)
 511            } else {
 512                None
 513            };
 514
 515            let cursor_point = position.to_point(&snapshot);
 516
 517            let debug_info = true;
 518            EditPredictionContext::gather_context(
 519                cursor_point,
 520                &snapshot,
 521                &excerpt_options,
 522                index_state.as_deref(),
 523            )
 524            .context("Failed to select excerpt")
 525            .map(|context| {
 526                make_cloud_request(
 527                    excerpt_path.clone(),
 528                    context,
 529                    // TODO pass everything
 530                    Vec::new(),
 531                    false,
 532                    Vec::new(),
 533                    None,
 534                    debug_info,
 535                    &worktree_snapshots,
 536                    index_state.as_deref(),
 537                )
 538            })
 539        })
 540    }
 541}
 542
 543#[derive(Error, Debug)]
 544#[error(
 545    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 546)]
 547pub struct ZedUpdateRequiredError {
 548    minimum_version: SemanticVersion,
 549}
 550
 551pub struct ZetaEditPredictionProvider {
 552    zeta: Entity<Zeta>,
 553    current_prediction: Option<CurrentEditPrediction>,
 554    next_pending_prediction_id: usize,
 555    pending_predictions: ArrayVec<PendingPrediction, 2>,
 556    last_request_timestamp: Instant,
 557}
 558
 559impl ZetaEditPredictionProvider {
 560    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 561
 562    pub fn new(
 563        project: Option<&Entity<Project>>,
 564        client: &Arc<Client>,
 565        user_store: &Entity<UserStore>,
 566        cx: &mut App,
 567    ) -> Self {
 568        let zeta = Zeta::global(client, user_store, cx);
 569        if let Some(project) = project {
 570            zeta.update(cx, |zeta, cx| {
 571                zeta.register_project(project, cx);
 572            });
 573        }
 574
 575        Self {
 576            zeta,
 577            current_prediction: None,
 578            next_pending_prediction_id: 0,
 579            pending_predictions: ArrayVec::new(),
 580            last_request_timestamp: Instant::now(),
 581        }
 582    }
 583}
 584
 585#[derive(Clone)]
 586struct CurrentEditPrediction {
 587    buffer_id: EntityId,
 588    prediction: EditPrediction,
 589}
 590
 591impl CurrentEditPrediction {
 592    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 593        if self.buffer_id != old_prediction.buffer_id {
 594            return true;
 595        }
 596
 597        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 598            return true;
 599        };
 600        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 601            return false;
 602        };
 603
 604        if old_edits.len() == 1 && new_edits.len() == 1 {
 605            let (old_range, old_text) = &old_edits[0];
 606            let (new_range, new_text) = &new_edits[0];
 607            new_range == old_range && new_text.starts_with(old_text)
 608        } else {
 609            true
 610        }
 611    }
 612}
 613
 614#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 615pub struct EditPredictionId(Uuid);
 616
 617impl From<EditPredictionId> for gpui::ElementId {
 618    fn from(value: EditPredictionId) -> Self {
 619        gpui::ElementId::Uuid(value.0)
 620    }
 621}
 622
 623impl std::fmt::Display for EditPredictionId {
 624    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 625        write!(f, "{}", self.0)
 626    }
 627}
 628
 629#[derive(Clone)]
 630pub struct EditPrediction {
 631    id: EditPredictionId,
 632    edits: Arc<[(Range<Anchor>, String)]>,
 633    snapshot: BufferSnapshot,
 634    edit_preview: EditPreview,
 635}
 636
 637impl EditPrediction {
 638    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 639        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 640    }
 641}
 642
 643struct PendingPrediction {
 644    id: usize,
 645    _task: Task<()>,
 646}
 647
 648impl EditPredictionProvider for ZetaEditPredictionProvider {
 649    fn name() -> &'static str {
 650        "zed-predict2"
 651    }
 652
 653    fn display_name() -> &'static str {
 654        "Zed's Edit Predictions 2"
 655    }
 656
 657    fn show_completions_in_menu() -> bool {
 658        true
 659    }
 660
 661    fn show_tab_accept_marker() -> bool {
 662        true
 663    }
 664
 665    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 666        // TODO [zeta2]
 667        DataCollectionState::Unsupported
 668    }
 669
 670    fn toggle_data_collection(&mut self, _cx: &mut App) {
 671        // TODO [zeta2]
 672    }
 673
 674    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 675        self.zeta.read(cx).usage(cx)
 676    }
 677
 678    fn is_enabled(
 679        &self,
 680        _buffer: &Entity<language::Buffer>,
 681        _cursor_position: language::Anchor,
 682        _cx: &App,
 683    ) -> bool {
 684        true
 685    }
 686
 687    fn is_refreshing(&self) -> bool {
 688        !self.pending_predictions.is_empty()
 689    }
 690
 691    fn refresh(
 692        &mut self,
 693        project: Option<Entity<project::Project>>,
 694        buffer: Entity<language::Buffer>,
 695        cursor_position: language::Anchor,
 696        _debounce: bool,
 697        cx: &mut Context<Self>,
 698    ) {
 699        let Some(project) = project else {
 700            return;
 701        };
 702
 703        if self
 704            .zeta
 705            .read(cx)
 706            .user_store
 707            .read_with(cx, |user_store, _cx| {
 708                user_store.account_too_young() || user_store.has_overdue_invoices()
 709            })
 710        {
 711            return;
 712        }
 713
 714        if let Some(current_prediction) = self.current_prediction.as_ref() {
 715            let snapshot = buffer.read(cx).snapshot();
 716            if current_prediction
 717                .prediction
 718                .interpolate(&snapshot)
 719                .is_some()
 720            {
 721                return;
 722            }
 723        }
 724
 725        let pending_prediction_id = self.next_pending_prediction_id;
 726        self.next_pending_prediction_id += 1;
 727        let last_request_timestamp = self.last_request_timestamp;
 728
 729        let task = cx.spawn(async move |this, cx| {
 730            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 731                .checked_duration_since(Instant::now())
 732            {
 733                cx.background_executor().timer(timeout).await;
 734            }
 735
 736            let prediction_request = this.update(cx, |this, cx| {
 737                this.last_request_timestamp = Instant::now();
 738                this.zeta.update(cx, |zeta, cx| {
 739                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 740                })
 741            });
 742
 743            let prediction = match prediction_request {
 744                Ok(prediction_request) => {
 745                    let prediction_request = prediction_request.await;
 746                    prediction_request.map(|c| {
 747                        c.map(|prediction| CurrentEditPrediction {
 748                            buffer_id: buffer.entity_id(),
 749                            prediction,
 750                        })
 751                    })
 752                }
 753                Err(error) => Err(error),
 754            };
 755
 756            this.update(cx, |this, cx| {
 757                if this.pending_predictions[0].id == pending_prediction_id {
 758                    this.pending_predictions.remove(0);
 759                } else {
 760                    this.pending_predictions.clear();
 761                }
 762
 763                let Some(new_prediction) = prediction
 764                    .context("edit prediction failed")
 765                    .log_err()
 766                    .flatten()
 767                else {
 768                    cx.notify();
 769                    return;
 770                };
 771
 772                if let Some(old_prediction) = this.current_prediction.as_ref() {
 773                    let snapshot = buffer.read(cx).snapshot();
 774                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 775                        this.current_prediction = Some(new_prediction);
 776                    }
 777                } else {
 778                    this.current_prediction = Some(new_prediction);
 779                }
 780
 781                cx.notify();
 782            })
 783            .ok();
 784        });
 785
 786        // We always maintain at most two pending predictions. When we already
 787        // have two, we replace the newest one.
 788        if self.pending_predictions.len() <= 1 {
 789            self.pending_predictions.push(PendingPrediction {
 790                id: pending_prediction_id,
 791                _task: task,
 792            });
 793        } else if self.pending_predictions.len() == 2 {
 794            self.pending_predictions.pop();
 795            self.pending_predictions.push(PendingPrediction {
 796                id: pending_prediction_id,
 797                _task: task,
 798            });
 799        }
 800
 801        cx.notify();
 802    }
 803
 804    fn cycle(
 805        &mut self,
 806        _buffer: Entity<language::Buffer>,
 807        _cursor_position: language::Anchor,
 808        _direction: Direction,
 809        _cx: &mut Context<Self>,
 810    ) {
 811    }
 812
 813    fn accept(&mut self, _cx: &mut Context<Self>) {
 814        // TODO [zeta2] report accept
 815        self.current_prediction.take();
 816        self.pending_predictions.clear();
 817    }
 818
 819    fn discard(&mut self, _cx: &mut Context<Self>) {
 820        self.pending_predictions.clear();
 821        self.current_prediction.take();
 822    }
 823
 824    fn suggest(
 825        &mut self,
 826        buffer: &Entity<language::Buffer>,
 827        cursor_position: language::Anchor,
 828        cx: &mut Context<Self>,
 829    ) -> Option<edit_prediction::EditPrediction> {
 830        let CurrentEditPrediction {
 831            buffer_id,
 832            prediction,
 833            ..
 834        } = self.current_prediction.as_mut()?;
 835
 836        // Invalidate previous prediction if it was generated for a different buffer.
 837        if *buffer_id != buffer.entity_id() {
 838            self.current_prediction.take();
 839            return None;
 840        }
 841
 842        let buffer = buffer.read(cx);
 843        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 844            self.current_prediction.take();
 845            return None;
 846        };
 847
 848        let cursor_row = cursor_position.to_point(buffer).row;
 849        let (closest_edit_ix, (closest_edit_range, _)) =
 850            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 851                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 852                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 853                cmp::min(distance_from_start, distance_from_end)
 854            })?;
 855
 856        let mut edit_start_ix = closest_edit_ix;
 857        for (range, _) in edits[..edit_start_ix].iter().rev() {
 858            let distance_from_closest_edit =
 859                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 860            if distance_from_closest_edit <= 1 {
 861                edit_start_ix -= 1;
 862            } else {
 863                break;
 864            }
 865        }
 866
 867        let mut edit_end_ix = closest_edit_ix + 1;
 868        for (range, _) in &edits[edit_end_ix..] {
 869            let distance_from_closest_edit =
 870                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 871            if distance_from_closest_edit <= 1 {
 872                edit_end_ix += 1;
 873            } else {
 874                break;
 875            }
 876        }
 877
 878        Some(edit_prediction::EditPrediction {
 879            id: Some(prediction.id.to_string().into()),
 880            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 881            edit_preview: Some(prediction.edit_preview.clone()),
 882        })
 883    }
 884}
 885
 886fn make_cloud_request(
 887    excerpt_path: PathBuf,
 888    context: EditPredictionContext,
 889    events: Vec<predict_edits_v3::Event>,
 890    can_collect_data: bool,
 891    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 892    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 893    debug_info: bool,
 894    worktrees: &Vec<worktree::Snapshot>,
 895    index_state: Option<&SyntaxIndexState>,
 896) -> predict_edits_v3::PredictEditsRequest {
 897    let mut signatures = Vec::new();
 898    let mut declaration_to_signature_index = HashMap::default();
 899    let mut referenced_declarations = Vec::new();
 900
 901    for snippet in context.snippets {
 902        let project_entry_id = snippet.declaration.project_entry_id();
 903        let Some(path) = worktrees.iter().find_map(|worktree| {
 904            worktree.entry_for_id(project_entry_id).map(|entry| {
 905                let mut full_path = PathBuf::new();
 906                full_path.push(worktree.root_name());
 907                full_path.push(&entry.path);
 908                full_path
 909            })
 910        }) else {
 911            continue;
 912        };
 913
 914        let parent_index = index_state.and_then(|index_state| {
 915            snippet.declaration.parent().and_then(|parent| {
 916                add_signature(
 917                    parent,
 918                    &mut declaration_to_signature_index,
 919                    &mut signatures,
 920                    index_state,
 921                )
 922            })
 923        });
 924
 925        let (text, text_is_truncated) = snippet.declaration.item_text();
 926        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 927            path,
 928            text: text.into(),
 929            range: snippet.declaration.item_range(),
 930            text_is_truncated,
 931            signature_range: snippet.declaration.signature_range_in_item_text(),
 932            parent_index,
 933            score_components: snippet.score_components,
 934            signature_score: snippet.scores.signature,
 935            declaration_score: snippet.scores.declaration,
 936        });
 937    }
 938
 939    let excerpt_parent = index_state.and_then(|index_state| {
 940        context
 941            .excerpt
 942            .parent_declarations
 943            .last()
 944            .and_then(|(parent, _)| {
 945                add_signature(
 946                    *parent,
 947                    &mut declaration_to_signature_index,
 948                    &mut signatures,
 949                    index_state,
 950                )
 951            })
 952    });
 953
 954    predict_edits_v3::PredictEditsRequest {
 955        excerpt_path,
 956        excerpt: context.excerpt_text.body,
 957        excerpt_range: context.excerpt.range,
 958        cursor_offset: context.cursor_offset_in_excerpt,
 959        referenced_declarations,
 960        signatures,
 961        excerpt_parent,
 962        events,
 963        can_collect_data,
 964        diagnostic_groups,
 965        git_info,
 966        debug_info,
 967    }
 968}
 969
 970fn add_signature(
 971    declaration_id: DeclarationId,
 972    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 973    signatures: &mut Vec<Signature>,
 974    index: &SyntaxIndexState,
 975) -> Option<usize> {
 976    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 977        return Some(*signature_index);
 978    }
 979    let Some(parent_declaration) = index.declaration(declaration_id) else {
 980        log::error!("bug: missing parent declaration");
 981        return None;
 982    };
 983    let parent_index = parent_declaration.parent().and_then(|parent| {
 984        add_signature(parent, declaration_to_signature_index, signatures, index)
 985    });
 986    let (text, text_is_truncated) = parent_declaration.signature_text();
 987    let signature_index = signatures.len();
 988    signatures.push(Signature {
 989        text: text.into(),
 990        text_is_truncated,
 991        parent_index,
 992        range: parent_declaration.signature_range(),
 993    });
 994    declaration_to_signature_index.insert(declaration_id, signature_index);
 995    Some(signature_index)
 996}
 997
 998fn interpolate(
 999    old_snapshot: &BufferSnapshot,
1000    new_snapshot: &BufferSnapshot,
1001    current_edits: Arc<[(Range<Anchor>, String)]>,
1002) -> Option<Vec<(Range<Anchor>, String)>> {
1003    let mut edits = Vec::new();
1004
1005    let mut model_edits = current_edits.iter().peekable();
1006    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1007        while let Some((model_old_range, _)) = model_edits.peek() {
1008            let model_old_range = model_old_range.to_offset(old_snapshot);
1009            if model_old_range.end < user_edit.old.start {
1010                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1011                edits.push((model_old_range.clone(), model_new_text.clone()));
1012            } else {
1013                break;
1014            }
1015        }
1016
1017        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1018            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1019            if user_edit.old == model_old_offset_range {
1020                let user_new_text = new_snapshot
1021                    .text_for_range(user_edit.new.clone())
1022                    .collect::<String>();
1023
1024                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1025                    if !model_suffix.is_empty() {
1026                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1027                        edits.push((anchor..anchor, model_suffix.to_string()));
1028                    }
1029
1030                    model_edits.next();
1031                    continue;
1032                }
1033            }
1034        }
1035
1036        return None;
1037    }
1038
1039    edits.extend(model_edits.cloned());
1040
1041    if edits.is_empty() { None } else { Some(edits) }
1042}
1043
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // No user edits yet: the prediction is reported as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the text targeted by the first edit turns it into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the replacement leaves only the untyped suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Once the replacement is fully typed, the first edit disappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Deleting part of the typed text brings the remaining suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the second edit by hand removes it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` onto the buffer's current snapshot and converts the result
    /// to offset ranges. Panics (via `unwrap`) if the prediction no longer applies.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let edits = prediction
            .interpolate(&buffer.read(cx).snapshot())
            .unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits.
    ///
    /// Each range starts at `anchor_after(start)` and ends at `anchor_before(end)`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buf = buffer.read(cx);
        let mut anchored = Vec::new();
        for (offsets, text) in iterator {
            let start = buf.anchor_after(offsets.start);
            let end = buf.anchor_before(offsets.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offsets in the buffer's current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buf = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (anchors, text) in editor_edits {
            let start = anchors.start.to_offset(buf);
            let end = anchors.end.to_offset(buf);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}