zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use gpui::http_client::Method;
  15use gpui::{
  16    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
  17    prelude::*,
  18};
  19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  20use language::{BufferSnapshot, EditPreview};
  21use language_model::{LlmApiToken, RefreshLlmTokenListener};
  22use project::Project;
  23use release_channel::AppVersion;
  24use std::cmp;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::path::PathBuf;
  27use std::str::FromStr as _;
  28use std::time::{Duration, Instant};
  29use std::{ops::Range, sync::Arc};
  30use thiserror::Error;
  31use util::ResultExt as _;
  32use uuid::Uuid;
  33use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  34
  35const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  36
  37/// Maximum number of events to track.
  38const MAX_EVENT_COUNT: usize = 16;
  39
  40#[derive(Clone)]
  41struct ZetaGlobal(Entity<Zeta>);
  42
  43impl Global for ZetaGlobal {}
  44
  45pub struct Zeta {
  46    client: Arc<Client>,
  47    user_store: Entity<UserStore>,
  48    llm_token: LlmApiToken,
  49    _llm_token_subscription: Subscription,
  50    projects: HashMap<EntityId, ZetaProject>,
  51    excerpt_options: EditPredictionExcerptOptions,
  52    update_required: bool,
  53}
  54
  55struct ZetaProject {
  56    syntax_index: Entity<SyntaxIndex>,
  57    events: VecDeque<Event>,
  58    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  59}
  60
  61struct RegisteredBuffer {
  62    snapshot: BufferSnapshot,
  63    _subscriptions: [gpui::Subscription; 2],
  64}
  65
  66#[derive(Clone)]
  67pub enum Event {
  68    BufferChange {
  69        old_snapshot: BufferSnapshot,
  70        new_snapshot: BufferSnapshot,
  71        timestamp: Instant,
  72    },
  73}
  74
  75impl Zeta {
  76    pub fn global(
  77        client: &Arc<Client>,
  78        user_store: &Entity<UserStore>,
  79        cx: &mut App,
  80    ) -> Entity<Self> {
  81        cx.try_global::<ZetaGlobal>()
  82            .map(|global| global.0.clone())
  83            .unwrap_or_else(|| {
  84                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
  85                cx.set_global(ZetaGlobal(zeta.clone()));
  86                zeta
  87            })
  88    }
  89
  90    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
  91        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
  92
  93        Self {
  94            projects: HashMap::new(),
  95            client,
  96            user_store,
  97            excerpt_options: EditPredictionExcerptOptions {
  98                max_bytes: 512,
  99                min_bytes: 128,
 100                target_before_cursor_over_total_bytes: 0.5,
 101            },
 102            llm_token: LlmApiToken::default(),
 103            _llm_token_subscription: cx.subscribe(
 104                &refresh_llm_token_listener,
 105                |this, _listener, _event, cx| {
 106                    let client = this.client.clone();
 107                    let llm_token = this.llm_token.clone();
 108                    cx.spawn(async move |_this, _cx| {
 109                        llm_token.refresh(&client).await?;
 110                        anyhow::Ok(())
 111                    })
 112                    .detach_and_log_err(cx);
 113                },
 114            ),
 115            update_required: false,
 116        }
 117    }
 118
 119    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 120        self.user_store.read(cx).edit_prediction_usage()
 121    }
 122
 123    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 124        self.get_or_init_zeta_project(project, cx);
 125    }
 126
 127    pub fn register_buffer(
 128        &mut self,
 129        buffer: &Entity<Buffer>,
 130        project: &Entity<Project>,
 131        cx: &mut Context<Self>,
 132    ) {
 133        let zeta_project = self.get_or_init_zeta_project(project, cx);
 134        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 135    }
 136
 137    fn get_or_init_zeta_project(
 138        &mut self,
 139        project: &Entity<Project>,
 140        cx: &mut App,
 141    ) -> &mut ZetaProject {
 142        self.projects
 143            .entry(project.entity_id())
 144            .or_insert_with(|| ZetaProject {
 145                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 146                events: VecDeque::new(),
 147                registered_buffers: HashMap::new(),
 148            })
 149    }
 150
 151    fn register_buffer_impl<'a>(
 152        zeta_project: &'a mut ZetaProject,
 153        buffer: &Entity<Buffer>,
 154        project: &Entity<Project>,
 155        cx: &mut Context<Self>,
 156    ) -> &'a mut RegisteredBuffer {
 157        let buffer_id = buffer.entity_id();
 158        match zeta_project.registered_buffers.entry(buffer_id) {
 159            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 160            hash_map::Entry::Vacant(entry) => {
 161                let snapshot = buffer.read(cx).snapshot();
 162                let project_entity_id = project.entity_id();
 163                entry.insert(RegisteredBuffer {
 164                    snapshot,
 165                    _subscriptions: [
 166                        cx.subscribe(buffer, {
 167                            let project = project.downgrade();
 168                            move |this, buffer, event, cx| {
 169                                if let language::BufferEvent::Edited = event
 170                                    && let Some(project) = project.upgrade()
 171                                {
 172                                    this.report_changes_for_buffer(&buffer, &project, cx);
 173                                }
 174                            }
 175                        }),
 176                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 177                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 178                            else {
 179                                return;
 180                            };
 181                            zeta_project.registered_buffers.remove(&buffer_id);
 182                        }),
 183                    ],
 184                })
 185            }
 186        }
 187    }
 188
 189    fn report_changes_for_buffer(
 190        &mut self,
 191        buffer: &Entity<Buffer>,
 192        project: &Entity<Project>,
 193        cx: &mut Context<Self>,
 194    ) -> BufferSnapshot {
 195        let zeta_project = self.get_or_init_zeta_project(project, cx);
 196        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 197
 198        let new_snapshot = buffer.read(cx).snapshot();
 199        if new_snapshot.version != registered_buffer.snapshot.version {
 200            let old_snapshot =
 201                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 202            Self::push_event(
 203                zeta_project,
 204                Event::BufferChange {
 205                    old_snapshot,
 206                    new_snapshot: new_snapshot.clone(),
 207                    timestamp: Instant::now(),
 208                },
 209            );
 210        }
 211
 212        new_snapshot
 213    }
 214
 215    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 216        let events = &mut zeta_project.events;
 217
 218        if let Some(Event::BufferChange {
 219            new_snapshot: last_new_snapshot,
 220            timestamp: last_timestamp,
 221            ..
 222        }) = events.back_mut()
 223        {
 224            // Coalesce edits for the same buffer when they happen one after the other.
 225            let Event::BufferChange {
 226                old_snapshot,
 227                new_snapshot,
 228                timestamp,
 229            } = &event;
 230
 231            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 232                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 233                && old_snapshot.version == last_new_snapshot.version
 234            {
 235                *last_new_snapshot = new_snapshot.clone();
 236                *last_timestamp = *timestamp;
 237                return;
 238            }
 239        }
 240
 241        if events.len() >= MAX_EVENT_COUNT {
 242            // These are halved instead of popping to improve prompt caching.
 243            events.drain(..MAX_EVENT_COUNT / 2);
 244        }
 245
 246        events.push_back(event);
 247    }
 248
 249    pub fn request_prediction(
 250        &mut self,
 251        project: &Entity<Project>,
 252        buffer: &Entity<Buffer>,
 253        position: language::Anchor,
 254        cx: &mut Context<Self>,
 255    ) -> Task<Result<Option<EditPrediction>>> {
 256        let project_state = self.projects.get(&project.entity_id());
 257
 258        let index_state = project_state.map(|state| {
 259            state
 260                .syntax_index
 261                .read_with(cx, |index, _cx| index.state().clone())
 262        });
 263        let excerpt_options = self.excerpt_options.clone();
 264        let snapshot = buffer.read(cx).snapshot();
 265        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 266            return Task::ready(Err(anyhow!("No file path for excerpt")));
 267        };
 268        let client = self.client.clone();
 269        let llm_token = self.llm_token.clone();
 270        let app_version = AppVersion::global(cx);
 271        let worktree_snapshots = project
 272            .read(cx)
 273            .worktrees(cx)
 274            .map(|worktree| worktree.read(cx).snapshot())
 275            .collect::<Vec<_>>();
 276
 277        let request_task = cx.background_spawn({
 278            let snapshot = snapshot.clone();
 279            async move {
 280                let index_state = if let Some(index_state) = index_state {
 281                    Some(index_state.lock_owned().await)
 282                } else {
 283                    None
 284                };
 285
 286                let cursor_point = position.to_point(&snapshot);
 287
 288                // TODO: make this only true if debug view is open
 289                let debug_info = true;
 290
 291                let Some(request) = EditPredictionContext::gather_context(
 292                    cursor_point,
 293                    &snapshot,
 294                    &excerpt_options,
 295                    index_state.as_deref(),
 296                )
 297                .map(|context| {
 298                    make_cloud_request(
 299                        excerpt_path.clone(),
 300                        context,
 301                        // TODO pass everything
 302                        Vec::new(),
 303                        false,
 304                        Vec::new(),
 305                        None,
 306                        debug_info,
 307                        &worktree_snapshots,
 308                        index_state.as_deref(),
 309                    )
 310                }) else {
 311                    return Ok(None);
 312                };
 313
 314                anyhow::Ok(Some(
 315                    Self::perform_request(client, llm_token, app_version, request).await?,
 316                ))
 317            }
 318        });
 319
 320        let buffer = buffer.clone();
 321
 322        cx.spawn(async move |this, cx| {
 323            match request_task.await {
 324                Ok(Some((response, usage))) => {
 325                    log::debug!("predicted edits: {:?}", &response.edits);
 326
 327                    if let Some(usage) = usage {
 328                        this.update(cx, |this, cx| {
 329                            this.user_store.update(cx, |user_store, cx| {
 330                                user_store.update_edit_prediction_usage(usage, cx);
 331                            });
 332                        })
 333                        .ok();
 334                    }
 335
 336                    // TODO telemetry: duration, etc
 337
 338                    // TODO produce smaller edits by diffing against snapshot first
 339                    //
 340                    // Cloud returns entire snippets/excerpts ranges as they were included
 341                    // in the request, but we should display smaller edits to the user.
 342                    //
 343                    // We can do this by computing a diff of each one against the snapshot.
 344                    // Similar to zeta::Zeta::compute_edits, but per edit.
 345                    let edits = response
 346                        .edits
 347                        .into_iter()
 348                        .map(|edit| {
 349                            // TODO edits to different files
 350                            (
 351                                snapshot.anchor_before(edit.range.start)
 352                                    ..snapshot.anchor_before(edit.range.end),
 353                                edit.content,
 354                            )
 355                        })
 356                        .collect::<Vec<_>>()
 357                        .into();
 358
 359                    let Some((edits, snapshot, edit_preview_task)) =
 360                        buffer.read_with(cx, |buffer, cx| {
 361                            let new_snapshot = buffer.snapshot();
 362                            let edits: Arc<[_]> =
 363                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 364                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 365                        })?
 366                    else {
 367                        return Ok(None);
 368                    };
 369
 370                    Ok(Some(EditPrediction {
 371                        id: EditPredictionId(response.request_id),
 372                        edits,
 373                        snapshot,
 374                        edit_preview: edit_preview_task.await,
 375                    }))
 376                }
 377                Ok(None) => Ok(None),
 378                Err(err) => {
 379                    if err.is::<ZedUpdateRequiredError>() {
 380                        cx.update(|cx| {
 381                            this.update(cx, |this, _cx| {
 382                                this.update_required = true;
 383                            })
 384                            .ok();
 385
 386                            let error_message: SharedString = err.to_string().into();
 387                            show_app_notification(
 388                                NotificationId::unique::<ZedUpdateRequiredError>(),
 389                                cx,
 390                                move |cx| {
 391                                    cx.new(|cx| {
 392                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 393                                            .with_link_button(
 394                                                "Update Zed",
 395                                                "https://zed.dev/releases",
 396                                            )
 397                                    })
 398                                },
 399                            );
 400                        })
 401                        .ok();
 402                    }
 403
 404                    Err(err)
 405                }
 406            }
 407        })
 408    }
 409
 410    async fn perform_request(
 411        client: Arc<Client>,
 412        llm_token: LlmApiToken,
 413        app_version: SemanticVersion,
 414        request: predict_edits_v3::PredictEditsRequest,
 415    ) -> Result<(
 416        predict_edits_v3::PredictEditsResponse,
 417        Option<EditPredictionUsage>,
 418    )> {
 419        let http_client = client.http_client();
 420        let mut token = llm_token.acquire(&client).await?;
 421        let mut did_retry = false;
 422
 423        loop {
 424            let request_builder = http_client::Request::builder().method(Method::POST);
 425            let request_builder =
 426                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 427                    request_builder.uri(predict_edits_url)
 428                } else {
 429                    request_builder.uri(
 430                        http_client
 431                            .build_zed_llm_url("/predict_edits/v3", &[])?
 432                            .as_ref(),
 433                    )
 434                };
 435            let request = request_builder
 436                .header("Content-Type", "application/json")
 437                .header("Authorization", format!("Bearer {}", token))
 438                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 439                .body(serde_json::to_string(&request)?.into())?;
 440
 441            let mut response = http_client.send(request).await?;
 442
 443            if let Some(minimum_required_version) = response
 444                .headers()
 445                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 446                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 447            {
 448                anyhow::ensure!(
 449                    app_version >= minimum_required_version,
 450                    ZedUpdateRequiredError {
 451                        minimum_version: minimum_required_version
 452                    }
 453                );
 454            }
 455
 456            if response.status().is_success() {
 457                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 458
 459                let mut body = Vec::new();
 460                response.body_mut().read_to_end(&mut body).await?;
 461                return Ok((serde_json::from_slice(&body)?, usage));
 462            } else if !did_retry
 463                && response
 464                    .headers()
 465                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 466                    .is_some()
 467            {
 468                did_retry = true;
 469                token = llm_token.refresh(&client).await?;
 470            } else {
 471                let mut body = String::new();
 472                response.body_mut().read_to_string(&mut body).await?;
 473                anyhow::bail!(
 474                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 475                    response.status(),
 476                    body
 477                );
 478            }
 479        }
 480    }
 481}
 482
 483#[derive(Error, Debug)]
 484#[error(
 485    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 486)]
 487pub struct ZedUpdateRequiredError {
 488    minimum_version: SemanticVersion,
 489}
 490
 491pub struct ZetaEditPredictionProvider {
 492    zeta: Entity<Zeta>,
 493    current_prediction: Option<CurrentEditPrediction>,
 494    next_pending_prediction_id: usize,
 495    pending_predictions: ArrayVec<PendingPrediction, 2>,
 496    last_request_timestamp: Instant,
 497}
 498
 499impl ZetaEditPredictionProvider {
 500    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 501
 502    pub fn new(
 503        project: Option<&Entity<Project>>,
 504        client: &Arc<Client>,
 505        user_store: &Entity<UserStore>,
 506        cx: &mut App,
 507    ) -> Self {
 508        let zeta = Zeta::global(client, user_store, cx);
 509        if let Some(project) = project {
 510            zeta.update(cx, |zeta, cx| {
 511                zeta.register_project(project, cx);
 512            });
 513        }
 514
 515        Self {
 516            zeta,
 517            current_prediction: None,
 518            next_pending_prediction_id: 0,
 519            pending_predictions: ArrayVec::new(),
 520            last_request_timestamp: Instant::now(),
 521        }
 522    }
 523}
 524
 525#[derive(Clone)]
 526struct CurrentEditPrediction {
 527    buffer_id: EntityId,
 528    prediction: EditPrediction,
 529}
 530
 531impl CurrentEditPrediction {
 532    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 533        if self.buffer_id != old_prediction.buffer_id {
 534            return true;
 535        }
 536
 537        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 538            return true;
 539        };
 540        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 541            return false;
 542        };
 543
 544        if old_edits.len() == 1 && new_edits.len() == 1 {
 545            let (old_range, old_text) = &old_edits[0];
 546            let (new_range, new_text) = &new_edits[0];
 547            new_range == old_range && new_text.starts_with(old_text)
 548        } else {
 549            true
 550        }
 551    }
 552}
 553
 554#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 555pub struct EditPredictionId(Uuid);
 556
 557impl From<EditPredictionId> for gpui::ElementId {
 558    fn from(value: EditPredictionId) -> Self {
 559        gpui::ElementId::Uuid(value.0)
 560    }
 561}
 562
 563impl std::fmt::Display for EditPredictionId {
 564    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 565        write!(f, "{}", self.0)
 566    }
 567}
 568
 569#[derive(Clone)]
 570pub struct EditPrediction {
 571    id: EditPredictionId,
 572    edits: Arc<[(Range<Anchor>, String)]>,
 573    snapshot: BufferSnapshot,
 574    edit_preview: EditPreview,
 575}
 576
 577impl EditPrediction {
 578    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 579        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 580    }
 581}
 582
 583struct PendingPrediction {
 584    id: usize,
 585    _task: Task<()>,
 586}
 587
 588impl EditPredictionProvider for ZetaEditPredictionProvider {
 589    fn name() -> &'static str {
 590        "zed-predict2"
 591    }
 592
 593    fn display_name() -> &'static str {
 594        "Zed's Edit Predictions 2"
 595    }
 596
 597    fn show_completions_in_menu() -> bool {
 598        true
 599    }
 600
 601    fn show_tab_accept_marker() -> bool {
 602        true
 603    }
 604
 605    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 606        // TODO [zeta2]
 607        DataCollectionState::Unsupported
 608    }
 609
 610    fn toggle_data_collection(&mut self, _cx: &mut App) {
 611        // TODO [zeta2]
 612    }
 613
 614    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 615        self.zeta.read(cx).usage(cx)
 616    }
 617
 618    fn is_enabled(
 619        &self,
 620        _buffer: &Entity<language::Buffer>,
 621        _cursor_position: language::Anchor,
 622        _cx: &App,
 623    ) -> bool {
 624        true
 625    }
 626
 627    fn is_refreshing(&self) -> bool {
 628        !self.pending_predictions.is_empty()
 629    }
 630
 631    fn refresh(
 632        &mut self,
 633        project: Option<Entity<project::Project>>,
 634        buffer: Entity<language::Buffer>,
 635        cursor_position: language::Anchor,
 636        _debounce: bool,
 637        cx: &mut Context<Self>,
 638    ) {
 639        let Some(project) = project else {
 640            return;
 641        };
 642
 643        if self
 644            .zeta
 645            .read(cx)
 646            .user_store
 647            .read_with(cx, |user_store, _cx| {
 648                user_store.account_too_young() || user_store.has_overdue_invoices()
 649            })
 650        {
 651            return;
 652        }
 653
 654        if let Some(current_prediction) = self.current_prediction.as_ref() {
 655            let snapshot = buffer.read(cx).snapshot();
 656            if current_prediction
 657                .prediction
 658                .interpolate(&snapshot)
 659                .is_some()
 660            {
 661                return;
 662            }
 663        }
 664
 665        let pending_prediction_id = self.next_pending_prediction_id;
 666        self.next_pending_prediction_id += 1;
 667        let last_request_timestamp = self.last_request_timestamp;
 668
 669        let task = cx.spawn(async move |this, cx| {
 670            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 671                .checked_duration_since(Instant::now())
 672            {
 673                cx.background_executor().timer(timeout).await;
 674            }
 675
 676            let prediction_request = this.update(cx, |this, cx| {
 677                this.last_request_timestamp = Instant::now();
 678                this.zeta.update(cx, |zeta, cx| {
 679                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 680                })
 681            });
 682
 683            let prediction = match prediction_request {
 684                Ok(prediction_request) => {
 685                    let prediction_request = prediction_request.await;
 686                    prediction_request.map(|c| {
 687                        c.map(|prediction| CurrentEditPrediction {
 688                            buffer_id: buffer.entity_id(),
 689                            prediction,
 690                        })
 691                    })
 692                }
 693                Err(error) => Err(error),
 694            };
 695
 696            this.update(cx, |this, cx| {
 697                if this.pending_predictions[0].id == pending_prediction_id {
 698                    this.pending_predictions.remove(0);
 699                } else {
 700                    this.pending_predictions.clear();
 701                }
 702
 703                let Some(new_prediction) = prediction
 704                    .context("edit prediction failed")
 705                    .log_err()
 706                    .flatten()
 707                else {
 708                    cx.notify();
 709                    return;
 710                };
 711
 712                if let Some(old_prediction) = this.current_prediction.as_ref() {
 713                    let snapshot = buffer.read(cx).snapshot();
 714                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 715                        this.current_prediction = Some(new_prediction);
 716                    }
 717                } else {
 718                    this.current_prediction = Some(new_prediction);
 719                }
 720
 721                cx.notify();
 722            })
 723            .ok();
 724        });
 725
 726        // We always maintain at most two pending predictions. When we already
 727        // have two, we replace the newest one.
 728        if self.pending_predictions.len() <= 1 {
 729            self.pending_predictions.push(PendingPrediction {
 730                id: pending_prediction_id,
 731                _task: task,
 732            });
 733        } else if self.pending_predictions.len() == 2 {
 734            self.pending_predictions.pop();
 735            self.pending_predictions.push(PendingPrediction {
 736                id: pending_prediction_id,
 737                _task: task,
 738            });
 739        }
 740
 741        cx.notify();
 742    }
 743
 744    fn cycle(
 745        &mut self,
 746        _buffer: Entity<language::Buffer>,
 747        _cursor_position: language::Anchor,
 748        _direction: Direction,
 749        _cx: &mut Context<Self>,
 750    ) {
 751    }
 752
 753    fn accept(&mut self, _cx: &mut Context<Self>) {
 754        // TODO [zeta2] report accept
 755        self.current_prediction.take();
 756        self.pending_predictions.clear();
 757    }
 758
 759    fn discard(&mut self, _cx: &mut Context<Self>) {
 760        self.pending_predictions.clear();
 761        self.current_prediction.take();
 762    }
 763
 764    fn suggest(
 765        &mut self,
 766        buffer: &Entity<language::Buffer>,
 767        cursor_position: language::Anchor,
 768        cx: &mut Context<Self>,
 769    ) -> Option<edit_prediction::EditPrediction> {
 770        let CurrentEditPrediction {
 771            buffer_id,
 772            prediction,
 773            ..
 774        } = self.current_prediction.as_mut()?;
 775
 776        // Invalidate previous prediction if it was generated for a different buffer.
 777        if *buffer_id != buffer.entity_id() {
 778            self.current_prediction.take();
 779            return None;
 780        }
 781
 782        let buffer = buffer.read(cx);
 783        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 784            self.current_prediction.take();
 785            return None;
 786        };
 787
 788        let cursor_row = cursor_position.to_point(buffer).row;
 789        let (closest_edit_ix, (closest_edit_range, _)) =
 790            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 791                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 792                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 793                cmp::min(distance_from_start, distance_from_end)
 794            })?;
 795
 796        let mut edit_start_ix = closest_edit_ix;
 797        for (range, _) in edits[..edit_start_ix].iter().rev() {
 798            let distance_from_closest_edit =
 799                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 800            if distance_from_closest_edit <= 1 {
 801                edit_start_ix -= 1;
 802            } else {
 803                break;
 804            }
 805        }
 806
 807        let mut edit_end_ix = closest_edit_ix + 1;
 808        for (range, _) in &edits[edit_end_ix..] {
 809            let distance_from_closest_edit =
 810                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 811            if distance_from_closest_edit <= 1 {
 812                edit_end_ix += 1;
 813            } else {
 814                break;
 815            }
 816        }
 817
 818        Some(edit_prediction::EditPrediction {
 819            id: Some(prediction.id.to_string().into()),
 820            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 821            edit_preview: Some(prediction.edit_preview.clone()),
 822        })
 823    }
 824}
 825
 826fn make_cloud_request(
 827    excerpt_path: PathBuf,
 828    context: EditPredictionContext,
 829    events: Vec<predict_edits_v3::Event>,
 830    can_collect_data: bool,
 831    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 832    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 833    debug_info: bool,
 834    worktrees: &Vec<worktree::Snapshot>,
 835    index_state: Option<&SyntaxIndexState>,
 836) -> predict_edits_v3::PredictEditsRequest {
 837    let mut signatures = Vec::new();
 838    let mut declaration_to_signature_index = HashMap::default();
 839    let mut referenced_declarations = Vec::new();
 840
 841    for snippet in context.snippets {
 842        let project_entry_id = snippet.declaration.project_entry_id();
 843        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
 844        // Note that currently full_path is currently being used for excerpt_path.
 845        let Some(path) = worktrees.iter().find_map(|worktree| {
 846            let abs_path = worktree.abs_path();
 847            worktree
 848                .entry_for_id(project_entry_id)
 849                .map(|e| abs_path.join(&e.path))
 850        }) else {
 851            continue;
 852        };
 853
 854        let parent_index = index_state.and_then(|index_state| {
 855            snippet.declaration.parent().and_then(|parent| {
 856                add_signature(
 857                    parent,
 858                    &mut declaration_to_signature_index,
 859                    &mut signatures,
 860                    index_state,
 861                )
 862            })
 863        });
 864
 865        let (text, text_is_truncated) = snippet.declaration.item_text();
 866        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 867            path,
 868            text: text.into(),
 869            range: snippet.declaration.item_range(),
 870            text_is_truncated,
 871            signature_range: snippet.declaration.signature_range_in_item_text(),
 872            parent_index,
 873            score_components: snippet.score_components,
 874            signature_score: snippet.scores.signature,
 875            declaration_score: snippet.scores.declaration,
 876        });
 877    }
 878
 879    let excerpt_parent = index_state.and_then(|index_state| {
 880        context
 881            .excerpt
 882            .parent_declarations
 883            .last()
 884            .and_then(|(parent, _)| {
 885                add_signature(
 886                    *parent,
 887                    &mut declaration_to_signature_index,
 888                    &mut signatures,
 889                    index_state,
 890                )
 891            })
 892    });
 893
 894    predict_edits_v3::PredictEditsRequest {
 895        excerpt_path,
 896        excerpt: context.excerpt_text.body,
 897        excerpt_range: context.excerpt.range,
 898        cursor_offset: context.cursor_offset_in_excerpt,
 899        referenced_declarations,
 900        signatures,
 901        excerpt_parent,
 902        events,
 903        can_collect_data,
 904        diagnostic_groups,
 905        git_info,
 906        debug_info,
 907    }
 908}
 909
 910fn add_signature(
 911    declaration_id: DeclarationId,
 912    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 913    signatures: &mut Vec<Signature>,
 914    index: &SyntaxIndexState,
 915) -> Option<usize> {
 916    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 917        return Some(*signature_index);
 918    }
 919    let Some(parent_declaration) = index.declaration(declaration_id) else {
 920        log::error!("bug: missing parent declaration");
 921        return None;
 922    };
 923    let parent_index = parent_declaration.parent().and_then(|parent| {
 924        add_signature(parent, declaration_to_signature_index, signatures, index)
 925    });
 926    let (text, text_is_truncated) = parent_declaration.signature_text();
 927    let signature_index = signatures.len();
 928    signatures.push(Signature {
 929        text: text.into(),
 930        text_is_truncated,
 931        parent_index,
 932    });
 933    declaration_to_signature_index.insert(declaration_id, signature_index);
 934    Some(signature_index)
 935}
 936
/// Rebases a prediction's edits onto `new_snapshot`, given the user edits made since
/// `old_snapshot`.
///
/// `current_edits` are anchored in `old_snapshot`. The user's edits (offsets from
/// `edits_since`) and the model's edits are walked together in order.
/// NOTE(review): this relies on `current_edits` being sorted by position and
/// non-overlapping; confirm at the call sites.
///
/// - A model edit whose old range ends strictly before the current user edit starts is kept
///   unchanged.
/// - A user edit can match the next model edit: it must replace exactly the same old range, and
///   its inserted text must be a prefix of the model's new text. Such an edit is treated as the
///   user typing part of the prediction. The model edit is replaced by an insertion of the
///   remaining suffix, or dropped if the user typed all of it.
/// - Any other user edit invalidates the prediction, and `None` is returned.
///
/// Model edits after the last user edit are kept unchanged. Returns `None` when no edits remain.
fn interpolate(
    old_snapshot: &BufferSnapshot,
    new_snapshot: &BufferSnapshot,
    current_edits: Arc<[(Range<Anchor>, String)]>,
) -> Option<Vec<(Range<Anchor>, String)>> {
    let mut edits = Vec::new();

    let mut model_edits = current_edits.iter().peekable();
    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
        // Carry over model edits that finish strictly before this user edit begins; the user
        // edit does not touch them.
        while let Some((model_old_range, _)) = model_edits.peek() {
            let model_old_range = model_old_range.to_offset(old_snapshot);
            if model_old_range.end < user_edit.old.start {
                let (model_old_range, model_new_text) = model_edits.next().unwrap();
                edits.push((model_old_range.clone(), model_new_text.clone()));
            } else {
                break;
            }
        }

        // The user edit is compatible with the prediction only if it replaced exactly the
        // same range as the next model edit...
        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
            if user_edit.old == model_old_offset_range {
                let user_new_text = new_snapshot
                    .text_for_range(user_edit.new.clone())
                    .collect::<String>();

                // ...and typed a prefix of what the model predicted.
                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
                    // Offer only the untyped remainder. It is anchored after the end of the
                    // replaced range so that it follows the user's typed text.
                    if !model_suffix.is_empty() {
                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
                        edits.push((anchor..anchor, model_suffix.to_string()));
                    }

                    model_edits.next();
                    continue;
                }
            }
        }

        // The user edit diverges from the prediction, or no model edit is left to match it.
        // Either way the prediction no longer applies.
        return None;
    }

    // No user edits remain: every pending model edit is still valid as-is.
    edits.extend(model_edits.cloned());

    if edits.is_empty() { None } else { Some(edits) }
}
 982
 983#[cfg(test)]
 984mod tests {
 985    use super::*;
 986    use gpui::TestAppContext;
 987    use language::ToOffset as _;
 988
 989    #[gpui::test]
 990    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
 991        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
 992        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
 993            to_prediction_edits(
 994                [(2..5, "REM".to_string()), (9..11, "".to_string())],
 995                &buffer,
 996                cx,
 997            )
 998            .into()
 999        });
1000
1001        let edit_preview = cx
1002            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1003            .await;
1004
1005        let prediction = EditPrediction {
1006            id: EditPredictionId(Uuid::new_v4()),
1007            edits,
1008            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1009            edit_preview,
1010        };
1011
1012        cx.update(|cx| {
1013            assert_eq!(
1014                from_prediction_edits(
1015                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1016                    &buffer,
1017                    cx
1018                ),
1019                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1020            );
1021
1022            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1023            assert_eq!(
1024                from_prediction_edits(
1025                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1026                    &buffer,
1027                    cx
1028                ),
1029                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1030            );
1031
1032            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1033            assert_eq!(
1034                from_prediction_edits(
1035                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1036                    &buffer,
1037                    cx
1038                ),
1039                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1040            );
1041
1042            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1043            assert_eq!(
1044                from_prediction_edits(
1045                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1046                    &buffer,
1047                    cx
1048                ),
1049                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1050            );
1051
1052            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1053            assert_eq!(
1054                from_prediction_edits(
1055                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1056                    &buffer,
1057                    cx
1058                ),
1059                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1060            );
1061
1062            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1063            assert_eq!(
1064                from_prediction_edits(
1065                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1066                    &buffer,
1067                    cx
1068                ),
1069                vec![(9..11, "".to_string())]
1070            );
1071
1072            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1073            assert_eq!(
1074                from_prediction_edits(
1075                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1076                    &buffer,
1077                    cx
1078                ),
1079                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1080            );
1081
1082            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1083            assert_eq!(
1084                from_prediction_edits(
1085                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1086                    &buffer,
1087                    cx
1088                ),
1089                vec![(4..4, "M".to_string())]
1090            );
1091
1092            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1093            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1094        })
1095    }
1096
1097    fn to_prediction_edits(
1098        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1099        buffer: &Entity<Buffer>,
1100        cx: &App,
1101    ) -> Vec<(Range<Anchor>, String)> {
1102        let buffer = buffer.read(cx);
1103        iterator
1104            .into_iter()
1105            .map(|(range, text)| {
1106                (
1107                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1108                    text,
1109                )
1110            })
1111            .collect()
1112    }
1113
1114    fn from_prediction_edits(
1115        editor_edits: &[(Range<Anchor>, String)],
1116        buffer: &Entity<Buffer>,
1117        cx: &App,
1118    ) -> Vec<(Range<usize>, String)> {
1119        let buffer = buffer.read(cx);
1120        editor_edits
1121            .iter()
1122            .map(|(range, text)| {
1123                (
1124                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1125                    text.clone(),
1126                )
1127            })
1128            .collect()
1129    }
1130}