zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use client::{Client, EditPredictionUsage, UserStore};
   4use cloud_llm_client::predict_edits_v3::{self, Signature};
   5use cloud_llm_client::{
   6    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   7};
   8use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
   9use edit_prediction_context::{
  10    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  11    SyntaxIndexState,
  12};
  13use futures::AsyncReadExt as _;
  14use gpui::http_client::Method;
  15use gpui::{
  16    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, http_client,
  17    prelude::*,
  18};
  19use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  20use language::{BufferSnapshot, EditPreview};
  21use language_model::{LlmApiToken, RefreshLlmTokenListener};
  22use project::Project;
  23use release_channel::AppVersion;
  24use std::cmp;
  25use std::collections::{HashMap, VecDeque, hash_map};
  26use std::fmt::Write;
  27use std::path::{Path, PathBuf};
  28use std::str::FromStr as _;
  29use std::time::{Duration, Instant};
  30use std::{ops::Range, sync::Arc};
  31use thiserror::Error;
  32use util::ResultExt as _;
  33use uuid::Uuid;
  34use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  35
  36const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  37
  38/// Maximum number of events to track.
  39const MAX_EVENT_COUNT: usize = 16;
  40
  41#[derive(Clone)]
  42struct ZetaGlobal(Entity<Zeta>);
  43
  44impl Global for ZetaGlobal {}
  45
  46pub struct Zeta {
  47    client: Arc<Client>,
  48    user_store: Entity<UserStore>,
  49    llm_token: LlmApiToken,
  50    _llm_token_subscription: Subscription,
  51    projects: HashMap<EntityId, ZetaProject>,
  52    excerpt_options: EditPredictionExcerptOptions,
  53    update_required: bool,
  54}
  55
  56struct ZetaProject {
  57    syntax_index: Entity<SyntaxIndex>,
  58    events: VecDeque<Event>,
  59    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  60}
  61
  62struct RegisteredBuffer {
  63    snapshot: BufferSnapshot,
  64    _subscriptions: [gpui::Subscription; 2],
  65}
  66
  67#[derive(Clone)]
  68pub enum Event {
  69    BufferChange {
  70        old_snapshot: BufferSnapshot,
  71        new_snapshot: BufferSnapshot,
  72        timestamp: Instant,
  73    },
  74}
  75
  76impl Event {
  77    //TODO: Actually use the events this in the prompt
  78    fn to_prompt(&self) -> String {
  79        match self {
  80            Event::BufferChange {
  81                old_snapshot,
  82                new_snapshot,
  83                ..
  84            } => {
  85                let mut prompt = String::new();
  86
  87                let old_path = old_snapshot
  88                    .file()
  89                    .map(|f| f.path().as_ref())
  90                    .unwrap_or(Path::new("untitled"));
  91                let new_path = new_snapshot
  92                    .file()
  93                    .map(|f| f.path().as_ref())
  94                    .unwrap_or(Path::new("untitled"));
  95                if old_path != new_path {
  96                    writeln!(prompt, "User renamed {:?} to {:?}\n", old_path, new_path).unwrap();
  97                }
  98
  99                let diff = language::unified_diff(&old_snapshot.text(), &new_snapshot.text());
 100                if !diff.is_empty() {
 101                    write!(
 102                        prompt,
 103                        "User edited {:?}:\n```diff\n{}\n```",
 104                        new_path, diff
 105                    )
 106                    .unwrap();
 107                }
 108
 109                prompt
 110            }
 111        }
 112    }
 113}
 114
impl Zeta {
    /// Returns the app-wide `Zeta` entity, creating it and storing it as a gpui
    /// global on first use.
    ///
    /// `client` and `user_store` are only used when no global exists yet; later
    /// calls return the already-registered entity.
    pub fn global(
        client: &Arc<Client>,
        user_store: &Entity<UserStore>,
        cx: &mut App,
    ) -> Entity<Self> {
        cx.try_global::<ZetaGlobal>()
            .map(|global| global.0.clone())
            .unwrap_or_else(|| {
                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
                cx.set_global(ZetaGlobal(zeta.clone()));
                zeta
            })
    }

    /// Builds a `Zeta` with no registered projects and the default excerpt
    /// sizing (512 max bytes, 128 min bytes, half of the excerpt before the
    /// cursor). It also subscribes to the global `RefreshLlmTokenListener` so
    /// the LLM token is refreshed whenever that listener emits an event.
    fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);

        Self {
            projects: HashMap::new(),
            client,
            user_store,
            excerpt_options: EditPredictionExcerptOptions {
                max_bytes: 512,
                min_bytes: 128,
                target_before_cursor_over_total_bytes: 0.5,
            },
            llm_token: LlmApiToken::default(),
            _llm_token_subscription: cx.subscribe(
                &refresh_llm_token_listener,
                |this, _listener, _event, cx| {
                    let client = this.client.clone();
                    let llm_token = this.llm_token.clone();
                    // Refresh asynchronously; the task is detached and any error is logged.
                    cx.spawn(async move |_this, _cx| {
                        llm_token.refresh(&client).await?;
                        anyhow::Ok(())
                    })
                    .detach_and_log_err(cx);
                },
            ),
            update_required: false,
        }
    }

    /// Returns the edit prediction usage currently recorded in the user store.
    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
        self.user_store.read(cx).edit_prediction_usage()
    }

    /// Ensures per-project state (including its `SyntaxIndex`) exists for `project`.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }

    /// Starts tracking edits to `buffer` within `project`, registering the
    /// project first if needed. Registering an already-tracked buffer is a no-op.
    pub fn register_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        Self::register_buffer_impl(zeta_project, buffer, project, cx);
    }

    /// Returns the state for `project`. On first access it creates that state
    /// with a fresh `SyntaxIndex`, no events and no registered buffers.
    fn get_or_init_zeta_project(
        &mut self,
        project: &Entity<Project>,
        cx: &mut App,
    ) -> &mut ZetaProject {
        self.projects
            .entry(project.entity_id())
            .or_insert_with(|| ZetaProject {
                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
                events: VecDeque::new(),
                registered_buffers: HashMap::new(),
            })
    }

    /// Returns the tracking entry for `buffer` in `zeta_project`, creating it if absent.
    ///
    /// A new entry records the buffer's current snapshot and holds two subscriptions:
    /// - one that calls `report_changes_for_buffer` on every `BufferEvent::Edited`,
    ///   as long as the project can still be upgraded;
    /// - one that removes the entry when the buffer entity is released.
    fn register_buffer_impl<'a>(
        zeta_project: &'a mut ZetaProject,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> &'a mut RegisteredBuffer {
        let buffer_id = buffer.entity_id();
        match zeta_project.registered_buffers.entry(buffer_id) {
            hash_map::Entry::Occupied(entry) => entry.into_mut(),
            hash_map::Entry::Vacant(entry) => {
                let snapshot = buffer.read(cx).snapshot();
                let project_entity_id = project.entity_id();
                entry.insert(RegisteredBuffer {
                    snapshot,
                    _subscriptions: [
                        cx.subscribe(buffer, {
                            // Capture only a weak handle to the project; edits are ignored
                            // once it can no longer be upgraded.
                            let project = project.downgrade();
                            move |this, buffer, event, cx| {
                                if let language::BufferEvent::Edited = event
                                    && let Some(project) = project.upgrade()
                                {
                                    this.report_changes_for_buffer(&buffer, &project, cx);
                                }
                            }
                        }),
                        cx.observe_release(buffer, move |this, _buffer, _cx| {
                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
                            else {
                                return;
                            };
                            zeta_project.registered_buffers.remove(&buffer_id);
                        }),
                    ],
                })
            }
        }
    }

    /// Records any change to `buffer` since its last recorded snapshot.
    ///
    /// Registers the project and buffer if needed. If the buffer's version
    /// differs from the recorded one, it stores the new snapshot and pushes a
    /// `BufferChange` event from the old snapshot to the new one.
    ///
    /// Returns the buffer's current snapshot.
    fn report_changes_for_buffer(
        &mut self,
        buffer: &Entity<Buffer>,
        project: &Entity<Project>,
        cx: &mut Context<Self>,
    ) -> BufferSnapshot {
        let zeta_project = self.get_or_init_zeta_project(project, cx);
        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);

        let new_snapshot = buffer.read(cx).snapshot();
        if new_snapshot.version != registered_buffer.snapshot.version {
            let old_snapshot =
                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
            Self::push_event(
                zeta_project,
                Event::BufferChange {
                    old_snapshot,
                    new_snapshot: new_snapshot.clone(),
                    timestamp: Instant::now(),
                },
            );
        }

        new_snapshot
    }

    /// Appends `event` to the project's event queue.
    ///
    /// The event is merged into the most recent one instead of being appended
    /// when all of the following hold:
    /// - the most recent event is a change to the same buffer;
    /// - that change ended at exactly the version where `event` starts;
    /// - `event` happened within `BUFFER_CHANGE_GROUPING_INTERVAL` of it.
    ///
    /// A merge extends the previous event's `new_snapshot` and `timestamp`.
    /// When the queue already holds `MAX_EVENT_COUNT` events, the oldest half
    /// is discarded before appending.
    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
        let events = &mut zeta_project.events;

        if let Some(Event::BufferChange {
            new_snapshot: last_new_snapshot,
            timestamp: last_timestamp,
            ..
        }) = events.back_mut()
        {
            // Coalesce edits for the same buffer when they happen one after the other.
            let Event::BufferChange {
                old_snapshot,
                new_snapshot,
                timestamp,
            } = &event;

            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
                && old_snapshot.version == last_new_snapshot.version
            {
                *last_new_snapshot = new_snapshot.clone();
                *last_timestamp = *timestamp;
                return;
            }
        }

        if events.len() >= MAX_EVENT_COUNT {
            // These are halved instead of popping to improve prompt caching.
            events.drain(..MAX_EVENT_COUNT / 2);
        }

        events.push_back(event);
    }

    /// Requests an edit prediction for `buffer` at `position`.
    ///
    /// Context is gathered on the background executor from the buffer snapshot
    /// and, if `project` has been registered, its syntax index state. The
    /// request is then sent via `perform_request`. The returned edits are turned
    /// into anchors in the snapshot taken here and passed through `interpolate`
    /// against the buffer's latest snapshot before an edit preview is built.
    ///
    /// Fails immediately if the buffer has no file. Resolves to `Ok(None)` when
    /// no context could be gathered or `interpolate` returns `None`. On a
    /// `ZedUpdateRequiredError` it sets `update_required` and shows an
    /// "Update Zed" notification before returning the error.
    ///
    /// NOTE(review): nothing in this impl reads `update_required`; callers are
    /// presumably expected to check it before requesting.
    pub fn request_prediction(
        &mut self,
        project: &Entity<Project>,
        buffer: &Entity<Buffer>,
        position: language::Anchor,
        cx: &mut Context<Self>,
    ) -> Task<Result<Option<EditPrediction>>> {
        let project_state = self.projects.get(&project.entity_id());

        let index_state = project_state.map(|state| {
            state
                .syntax_index
                .read_with(cx, |index, _cx| index.state().clone())
        });
        let excerpt_options = self.excerpt_options.clone();
        let snapshot = buffer.read(cx).snapshot();
        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
            return Task::ready(Err(anyhow!("No file path for excerpt")));
        };
        let client = self.client.clone();
        let llm_token = self.llm_token.clone();
        let app_version = AppVersion::global(cx);
        let worktree_snapshots = project
            .read(cx)
            .worktrees(cx)
            .map(|worktree| worktree.read(cx).snapshot())
            .collect::<Vec<_>>();

        let request_task = cx.background_spawn({
            let snapshot = snapshot.clone();
            async move {
                // Hold the index state lock (if any) while gathering context and
                // building the request.
                let index_state = if let Some(index_state) = index_state {
                    Some(index_state.lock_owned().await)
                } else {
                    None
                };

                let cursor_point = position.to_point(&snapshot);

                // TODO: make this only true if debug view is open
                let debug_info = true;

                let Some(request) = EditPredictionContext::gather_context(
                    cursor_point,
                    &snapshot,
                    &excerpt_options,
                    index_state.as_deref(),
                )
                .map(|context| {
                    make_cloud_request(
                        excerpt_path.clone(),
                        context,
                        // TODO pass everything
                        Vec::new(),
                        false,
                        Vec::new(),
                        None,
                        debug_info,
                        &worktree_snapshots,
                        index_state.as_deref(),
                    )
                }) else {
                    return Ok(None);
                };

                anyhow::Ok(Some(
                    Self::perform_request(client, llm_token, app_version, request).await?,
                ))
            }
        });

        let buffer = buffer.clone();

        cx.spawn(async move |this, cx| {
            match request_task.await {
                Ok(Some((response, usage))) => {
                    log::debug!("predicted edits: {:?}", &response.edits);

                    if let Some(usage) = usage {
                        this.update(cx, |this, cx| {
                            this.user_store.update(cx, |user_store, cx| {
                                user_store.update_edit_prediction_usage(usage, cx);
                            });
                        })
                        .ok();
                    }

                    // TODO telemetry: duration, etc

                    // TODO produce smaller edits by diffing against snapshot first
                    //
                    // Cloud returns entire snippets/excerpts ranges as they were included
                    // in the request, but we should display smaller edits to the user.
                    //
                    // We can do this by computing a diff of each one against the snapshot.
                    // Similar to zeta::Zeta::compute_edits, but per edit.
                    //
                    // Edit ranges are anchored in the request-time `snapshot`; presumably
                    // they are offsets into it — TODO confirm against predict_edits_v3.
                    let edits = response
                        .edits
                        .into_iter()
                        .map(|edit| {
                            // TODO edits to different files
                            (
                                snapshot.anchor_before(edit.range.start)
                                    ..snapshot.anchor_before(edit.range.end),
                                edit.content,
                            )
                        })
                        .collect::<Vec<_>>()
                        .into();

                    // Re-base the edits onto the buffer as it is now (the user may have
                    // typed while the request was in flight); give up if that fails.
                    let Some((edits, snapshot, edit_preview_task)) =
                        buffer.read_with(cx, |buffer, cx| {
                            let new_snapshot = buffer.snapshot();
                            let edits: Arc<[_]> =
                                interpolate(&snapshot, &new_snapshot, edits)?.into();
                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
                        })?
                    else {
                        return Ok(None);
                    };

                    Ok(Some(EditPrediction {
                        id: EditPredictionId(response.request_id),
                        edits,
                        snapshot,
                        edit_preview: edit_preview_task.await,
                    }))
                }
                Ok(None) => Ok(None),
                Err(err) => {
                    if err.is::<ZedUpdateRequiredError>() {
                        cx.update(|cx| {
                            this.update(cx, |this, _cx| {
                                this.update_required = true;
                            })
                            .ok();

                            let error_message: SharedString = err.to_string().into();
                            show_app_notification(
                                NotificationId::unique::<ZedUpdateRequiredError>(),
                                cx,
                                move |cx| {
                                    cx.new(|cx| {
                                        ErrorMessagePrompt::new(error_message.clone(), cx)
                                            .with_link_button(
                                                "Update Zed",
                                                "https://zed.dev/releases",
                                            )
                                    })
                                },
                            );
                        })
                        .ok();
                    }

                    Err(err)
                }
            }
        })
    }

    /// Sends `request` to the prediction endpoint. Returns the parsed response
    /// together with any usage information parsed from the response headers.
    ///
    /// The endpoint is `ZED_PREDICT_EDITS_URL` when that environment variable is
    /// set, otherwise the Zed LLM `/predict_edits/v3` URL. If a non-success
    /// response carries the expired-token header, the token is refreshed and the
    /// request is retried once.
    ///
    /// # Errors
    /// - `ZedUpdateRequiredError` when the response's minimum-required-version
    ///   header is newer than `app_version`. This is checked before the status.
    /// - Any token, URL, serialization, transport or response-parsing failure.
    /// - Any other non-success status, reported with the response body.
    async fn perform_request(
        client: Arc<Client>,
        llm_token: LlmApiToken,
        app_version: SemanticVersion,
        request: predict_edits_v3::PredictEditsRequest,
    ) -> Result<(
        predict_edits_v3::PredictEditsResponse,
        Option<EditPredictionUsage>,
    )> {
        let http_client = client.http_client();
        let mut token = llm_token.acquire(&client).await?;
        let mut did_retry = false;

        loop {
            let request_builder = http_client::Request::builder().method(Method::POST);
            let request_builder =
                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
                    request_builder.uri(predict_edits_url)
                } else {
                    request_builder.uri(
                        http_client
                            .build_zed_llm_url("/predict_edits/v3", &[])?
                            .as_ref(),
                    )
                };
            // Shadows the outer `request` only for this iteration; a retry re-serializes
            // the original prediction request.
            let request = request_builder
                .header("Content-Type", "application/json")
                .header("Authorization", format!("Bearer {}", token))
                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
                .body(serde_json::to_string(&request)?.into())?;

            let mut response = http_client.send(request).await?;

            if let Some(minimum_required_version) = response
                .headers()
                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
            {
                anyhow::ensure!(
                    app_version >= minimum_required_version,
                    ZedUpdateRequiredError {
                        minimum_version: minimum_required_version
                    }
                );
            }

            if response.status().is_success() {
                let usage = EditPredictionUsage::from_headers(response.headers()).ok();

                let mut body = Vec::new();
                response.body_mut().read_to_end(&mut body).await?;
                return Ok((serde_json::from_slice(&body)?, usage));
            } else if !did_retry
                && response
                    .headers()
                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
                    .is_some()
            {
                did_retry = true;
                token = llm_token.refresh(&client).await?;
            } else {
                let mut body = String::new();
                response.body_mut().read_to_string(&mut body).await?;
                anyhow::bail!(
                    "error predicting edits.\nStatus: {:?}\nBody: {}",
                    response.status(),
                    body
                );
            }
        }
    }
}
 522
 523#[derive(Error, Debug)]
 524#[error(
 525    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 526)]
 527pub struct ZedUpdateRequiredError {
 528    minimum_version: SemanticVersion,
 529}
 530
 531pub struct ZetaEditPredictionProvider {
 532    zeta: Entity<Zeta>,
 533    current_prediction: Option<CurrentEditPrediction>,
 534    next_pending_prediction_id: usize,
 535    pending_predictions: ArrayVec<PendingPrediction, 2>,
 536    last_request_timestamp: Instant,
 537}
 538
 539impl ZetaEditPredictionProvider {
 540    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 541
 542    pub fn new(
 543        project: Option<&Entity<Project>>,
 544        client: &Arc<Client>,
 545        user_store: &Entity<UserStore>,
 546        cx: &mut App,
 547    ) -> Self {
 548        let zeta = Zeta::global(client, user_store, cx);
 549        if let Some(project) = project {
 550            zeta.update(cx, |zeta, cx| {
 551                zeta.register_project(project, cx);
 552            });
 553        }
 554
 555        Self {
 556            zeta,
 557            current_prediction: None,
 558            next_pending_prediction_id: 0,
 559            pending_predictions: ArrayVec::new(),
 560            last_request_timestamp: Instant::now(),
 561        }
 562    }
 563}
 564
 565#[derive(Clone)]
 566struct CurrentEditPrediction {
 567    buffer_id: EntityId,
 568    prediction: EditPrediction,
 569}
 570
 571impl CurrentEditPrediction {
 572    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 573        if self.buffer_id != old_prediction.buffer_id {
 574            return true;
 575        }
 576
 577        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 578            return true;
 579        };
 580        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 581            return false;
 582        };
 583
 584        if old_edits.len() == 1 && new_edits.len() == 1 {
 585            let (old_range, old_text) = &old_edits[0];
 586            let (new_range, new_text) = &new_edits[0];
 587            new_range == old_range && new_text.starts_with(old_text)
 588        } else {
 589            true
 590        }
 591    }
 592}
 593
 594#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 595pub struct EditPredictionId(Uuid);
 596
 597impl From<EditPredictionId> for gpui::ElementId {
 598    fn from(value: EditPredictionId) -> Self {
 599        gpui::ElementId::Uuid(value.0)
 600    }
 601}
 602
 603impl std::fmt::Display for EditPredictionId {
 604    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 605        write!(f, "{}", self.0)
 606    }
 607}
 608
 609#[derive(Clone)]
 610pub struct EditPrediction {
 611    id: EditPredictionId,
 612    edits: Arc<[(Range<Anchor>, String)]>,
 613    snapshot: BufferSnapshot,
 614    edit_preview: EditPreview,
 615}
 616
 617impl EditPrediction {
 618    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 619        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 620    }
 621}
 622
 623struct PendingPrediction {
 624    id: usize,
 625    _task: Task<()>,
 626}
 627
 628impl EditPredictionProvider for ZetaEditPredictionProvider {
 629    fn name() -> &'static str {
 630        "zed-predict2"
 631    }
 632
 633    fn display_name() -> &'static str {
 634        "Zed's Edit Predictions 2"
 635    }
 636
 637    fn show_completions_in_menu() -> bool {
 638        true
 639    }
 640
 641    fn show_tab_accept_marker() -> bool {
 642        true
 643    }
 644
 645    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 646        // TODO [zeta2]
 647        DataCollectionState::Unsupported
 648    }
 649
 650    fn toggle_data_collection(&mut self, _cx: &mut App) {
 651        // TODO [zeta2]
 652    }
 653
 654    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 655        self.zeta.read(cx).usage(cx)
 656    }
 657
 658    fn is_enabled(
 659        &self,
 660        _buffer: &Entity<language::Buffer>,
 661        _cursor_position: language::Anchor,
 662        _cx: &App,
 663    ) -> bool {
 664        true
 665    }
 666
 667    fn is_refreshing(&self) -> bool {
 668        !self.pending_predictions.is_empty()
 669    }
 670
 671    fn refresh(
 672        &mut self,
 673        project: Option<Entity<project::Project>>,
 674        buffer: Entity<language::Buffer>,
 675        cursor_position: language::Anchor,
 676        _debounce: bool,
 677        cx: &mut Context<Self>,
 678    ) {
 679        let Some(project) = project else {
 680            return;
 681        };
 682
 683        if self
 684            .zeta
 685            .read(cx)
 686            .user_store
 687            .read_with(cx, |user_store, _cx| {
 688                user_store.account_too_young() || user_store.has_overdue_invoices()
 689            })
 690        {
 691            return;
 692        }
 693
 694        if let Some(current_prediction) = self.current_prediction.as_ref() {
 695            let snapshot = buffer.read(cx).snapshot();
 696            if current_prediction
 697                .prediction
 698                .interpolate(&snapshot)
 699                .is_some()
 700            {
 701                return;
 702            }
 703        }
 704
 705        let pending_prediction_id = self.next_pending_prediction_id;
 706        self.next_pending_prediction_id += 1;
 707        let last_request_timestamp = self.last_request_timestamp;
 708
 709        let task = cx.spawn(async move |this, cx| {
 710            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 711                .checked_duration_since(Instant::now())
 712            {
 713                cx.background_executor().timer(timeout).await;
 714            }
 715
 716            let prediction_request = this.update(cx, |this, cx| {
 717                this.last_request_timestamp = Instant::now();
 718                this.zeta.update(cx, |zeta, cx| {
 719                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 720                })
 721            });
 722
 723            let prediction = match prediction_request {
 724                Ok(prediction_request) => {
 725                    let prediction_request = prediction_request.await;
 726                    prediction_request.map(|c| {
 727                        c.map(|prediction| CurrentEditPrediction {
 728                            buffer_id: buffer.entity_id(),
 729                            prediction,
 730                        })
 731                    })
 732                }
 733                Err(error) => Err(error),
 734            };
 735
 736            this.update(cx, |this, cx| {
 737                if this.pending_predictions[0].id == pending_prediction_id {
 738                    this.pending_predictions.remove(0);
 739                } else {
 740                    this.pending_predictions.clear();
 741                }
 742
 743                let Some(new_prediction) = prediction
 744                    .context("edit prediction failed")
 745                    .log_err()
 746                    .flatten()
 747                else {
 748                    cx.notify();
 749                    return;
 750                };
 751
 752                if let Some(old_prediction) = this.current_prediction.as_ref() {
 753                    let snapshot = buffer.read(cx).snapshot();
 754                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 755                        this.current_prediction = Some(new_prediction);
 756                    }
 757                } else {
 758                    this.current_prediction = Some(new_prediction);
 759                }
 760
 761                cx.notify();
 762            })
 763            .ok();
 764        });
 765
 766        // We always maintain at most two pending predictions. When we already
 767        // have two, we replace the newest one.
 768        if self.pending_predictions.len() <= 1 {
 769            self.pending_predictions.push(PendingPrediction {
 770                id: pending_prediction_id,
 771                _task: task,
 772            });
 773        } else if self.pending_predictions.len() == 2 {
 774            self.pending_predictions.pop();
 775            self.pending_predictions.push(PendingPrediction {
 776                id: pending_prediction_id,
 777                _task: task,
 778            });
 779        }
 780
 781        cx.notify();
 782    }
 783
 784    fn cycle(
 785        &mut self,
 786        _buffer: Entity<language::Buffer>,
 787        _cursor_position: language::Anchor,
 788        _direction: Direction,
 789        _cx: &mut Context<Self>,
 790    ) {
 791    }
 792
 793    fn accept(&mut self, _cx: &mut Context<Self>) {
 794        // TODO [zeta2] report accept
 795        self.current_prediction.take();
 796        self.pending_predictions.clear();
 797    }
 798
 799    fn discard(&mut self, _cx: &mut Context<Self>) {
 800        self.pending_predictions.clear();
 801        self.current_prediction.take();
 802    }
 803
 804    fn suggest(
 805        &mut self,
 806        buffer: &Entity<language::Buffer>,
 807        cursor_position: language::Anchor,
 808        cx: &mut Context<Self>,
 809    ) -> Option<edit_prediction::EditPrediction> {
 810        let CurrentEditPrediction {
 811            buffer_id,
 812            prediction,
 813            ..
 814        } = self.current_prediction.as_mut()?;
 815
 816        // Invalidate previous prediction if it was generated for a different buffer.
 817        if *buffer_id != buffer.entity_id() {
 818            self.current_prediction.take();
 819            return None;
 820        }
 821
 822        let buffer = buffer.read(cx);
 823        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 824            self.current_prediction.take();
 825            return None;
 826        };
 827
 828        let cursor_row = cursor_position.to_point(buffer).row;
 829        let (closest_edit_ix, (closest_edit_range, _)) =
 830            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 831                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 832                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 833                cmp::min(distance_from_start, distance_from_end)
 834            })?;
 835
 836        let mut edit_start_ix = closest_edit_ix;
 837        for (range, _) in edits[..edit_start_ix].iter().rev() {
 838            let distance_from_closest_edit =
 839                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 840            if distance_from_closest_edit <= 1 {
 841                edit_start_ix -= 1;
 842            } else {
 843                break;
 844            }
 845        }
 846
 847        let mut edit_end_ix = closest_edit_ix + 1;
 848        for (range, _) in &edits[edit_end_ix..] {
 849            let distance_from_closest_edit =
 850                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 851            if distance_from_closest_edit <= 1 {
 852                edit_end_ix += 1;
 853            } else {
 854                break;
 855            }
 856        }
 857
 858        Some(edit_prediction::EditPrediction {
 859            id: Some(prediction.id.to_string().into()),
 860            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 861            edit_preview: Some(prediction.edit_preview.clone()),
 862        })
 863    }
 864}
 865
 866fn make_cloud_request(
 867    excerpt_path: PathBuf,
 868    context: EditPredictionContext,
 869    events: Vec<predict_edits_v3::Event>,
 870    can_collect_data: bool,
 871    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 872    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 873    debug_info: bool,
 874    worktrees: &Vec<worktree::Snapshot>,
 875    index_state: Option<&SyntaxIndexState>,
 876) -> predict_edits_v3::PredictEditsRequest {
 877    let mut signatures = Vec::new();
 878    let mut declaration_to_signature_index = HashMap::default();
 879    let mut referenced_declarations = Vec::new();
 880
 881    for snippet in context.snippets {
 882        let project_entry_id = snippet.declaration.project_entry_id();
 883        // TODO: Use full paths (worktree rooted) - need to move full_path method to the snapshot.
 884        // Note that currently full_path is currently being used for excerpt_path.
 885        let Some(path) = worktrees.iter().find_map(|worktree| {
 886            let abs_path = worktree.abs_path();
 887            worktree
 888                .entry_for_id(project_entry_id)
 889                .map(|e| abs_path.join(&e.path))
 890        }) else {
 891            continue;
 892        };
 893
 894        let parent_index = index_state.and_then(|index_state| {
 895            snippet.declaration.parent().and_then(|parent| {
 896                add_signature(
 897                    parent,
 898                    &mut declaration_to_signature_index,
 899                    &mut signatures,
 900                    index_state,
 901                )
 902            })
 903        });
 904
 905        let (text, text_is_truncated) = snippet.declaration.item_text();
 906        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 907            path,
 908            text: text.into(),
 909            range: snippet.declaration.item_range(),
 910            text_is_truncated,
 911            signature_range: snippet.declaration.signature_range_in_item_text(),
 912            parent_index,
 913            score_components: snippet.score_components,
 914            signature_score: snippet.scores.signature,
 915            declaration_score: snippet.scores.declaration,
 916        });
 917    }
 918
 919    let excerpt_parent = index_state.and_then(|index_state| {
 920        context
 921            .excerpt
 922            .parent_declarations
 923            .last()
 924            .and_then(|(parent, _)| {
 925                add_signature(
 926                    *parent,
 927                    &mut declaration_to_signature_index,
 928                    &mut signatures,
 929                    index_state,
 930                )
 931            })
 932    });
 933
 934    predict_edits_v3::PredictEditsRequest {
 935        excerpt_path,
 936        excerpt: context.excerpt_text.body,
 937        excerpt_range: context.excerpt.range,
 938        cursor_offset: context.cursor_offset_in_excerpt,
 939        referenced_declarations,
 940        signatures,
 941        excerpt_parent,
 942        // todo!
 943        events,
 944        can_collect_data,
 945        diagnostic_groups,
 946        git_info,
 947        debug_info,
 948    }
 949}
 950
 951fn add_signature(
 952    declaration_id: DeclarationId,
 953    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
 954    signatures: &mut Vec<Signature>,
 955    index: &SyntaxIndexState,
 956) -> Option<usize> {
 957    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
 958        return Some(*signature_index);
 959    }
 960    let Some(parent_declaration) = index.declaration(declaration_id) else {
 961        log::error!("bug: missing parent declaration");
 962        return None;
 963    };
 964    let parent_index = parent_declaration.parent().and_then(|parent| {
 965        add_signature(parent, declaration_to_signature_index, signatures, index)
 966    });
 967    let (text, text_is_truncated) = parent_declaration.signature_text();
 968    let signature_index = signatures.len();
 969    signatures.push(Signature {
 970        text: text.into(),
 971        text_is_truncated,
 972        parent_index,
 973    });
 974    declaration_to_signature_index.insert(declaration_id, signature_index);
 975    Some(signature_index)
 976}
 977
 978fn interpolate(
 979    old_snapshot: &BufferSnapshot,
 980    new_snapshot: &BufferSnapshot,
 981    current_edits: Arc<[(Range<Anchor>, String)]>,
 982) -> Option<Vec<(Range<Anchor>, String)>> {
 983    let mut edits = Vec::new();
 984
 985    let mut model_edits = current_edits.iter().peekable();
 986    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
 987        while let Some((model_old_range, _)) = model_edits.peek() {
 988            let model_old_range = model_old_range.to_offset(old_snapshot);
 989            if model_old_range.end < user_edit.old.start {
 990                let (model_old_range, model_new_text) = model_edits.next().unwrap();
 991                edits.push((model_old_range.clone(), model_new_text.clone()));
 992            } else {
 993                break;
 994            }
 995        }
 996
 997        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
 998            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
 999            if user_edit.old == model_old_offset_range {
1000                let user_new_text = new_snapshot
1001                    .text_for_range(user_edit.new.clone())
1002                    .collect::<String>();
1003
1004                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1005                    if !model_suffix.is_empty() {
1006                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1007                        edits.push((anchor..anchor, model_suffix.to_string()));
1008                    }
1009
1010                    model_edits.next();
1011                    continue;
1012                }
1013            }
1014        }
1015
1016        return None;
1017    }
1018
1019    edits.extend(model_edits.cloned());
1020
1021    if edits.is_empty() { None } else { Some(edits) }
1022}
1023
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Exercises `EditPrediction::interpolate` as the user types over, undoes, and
    /// diverges from a two-edit prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        // The model proposes replacing "rem" (2..5) with "REM" and deleting "um" (9..11).
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });
        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;
        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back verbatim.
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // "Lo ipsum dolor": the whole "REM" becomes an insertion; the deletion shifts left.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            // Undo restores the original text and the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // "LoR ipsum dolor": only the untyped suffix "EM" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            // "LoRE ipsum dolor": only "M" remains.
            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // "LoREM ipsum dolor": the first edit is fully typed and disappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            // Back to "LoRE ipsum dolor": the "M" insertion reappears.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // "LoRE ips dolor": the user performed the predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // "LoREps dolor": the user edit no longer matches any model edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and converts the
    /// surviving edits to offset ranges; `None` when the prediction no longer applies.
    fn interpolated(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let edits = prediction.interpolate(&buffer.read(cx).snapshot())?;
        Some(from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based prediction edits for `buffer`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        iterator
            .into_iter()
            .map(|(range, text)| {
                (
                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
                    text,
                )
            })
            .collect()
    }

    /// Resolves anchor-based prediction edits back into offset ranges in `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        editor_edits
            .iter()
            .map(|(range, text)| {
                (
                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
                    text.clone(),
                )
            })
            .collect()
    }
}