zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  10use edit_prediction_context::{
  11    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  12    SyntaxIndexState,
  13};
  14use futures::AsyncReadExt as _;
  15use futures::channel::mpsc;
  16use gpui::http_client::Method;
  17use gpui::{
  18    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  19    http_client, prelude::*,
  20};
  21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  22use language::{BufferSnapshot, EditPreview};
  23use language_model::{LlmApiToken, RefreshLlmTokenListener};
  24use project::Project;
  25use release_channel::AppVersion;
  26use std::cmp;
  27use std::collections::{HashMap, VecDeque, hash_map};
  28use std::path::PathBuf;
  29use std::str::FromStr as _;
  30use std::time::{Duration, Instant};
  31use std::{ops::Range, sync::Arc};
  32use thiserror::Error;
  33use util::{ResultExt as _, some_or_debug_panic};
  34use uuid::Uuid;
  35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  36
  37const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  38
  39/// Maximum number of events to track.
  40const MAX_EVENT_COUNT: usize = 16;
  41
  42pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  43    max_bytes: 512,
  44    min_bytes: 128,
  45    target_before_cursor_over_total_bytes: 0.5,
  46};
  47
/// App-global handle to the shared [`Zeta`] entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  52
/// Entry point for zeta2 edit predictions: tracks per-project buffer change history and
/// sends prediction requests to the cloud.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    llm_token: LlmApiToken,
    /// Subscription that refreshes `llm_token` when the global refresh listener fires; stored so
    /// it is not dropped.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    /// Options used to select the excerpt around the cursor.
    pub excerpt_options: EditPredictionExcerptOptions,
    /// Set once the server has rejected a request because this Zed version is too old.
    update_required: bool,
    /// When set (via `debug_info`), receives debug info or an error message for each prediction
    /// request.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  63
/// Debug information about a single prediction request, sent on the channel returned by
/// `Zeta::debug_info`.
pub struct PredictionDebugInfo {
    /// Context gathered around the cursor for this request.
    pub context: EditPredictionContext,
    /// Time from just before context gathering until the cloud request had been built.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in the prediction response.
    pub request: RequestDebugInfo,
    /// Buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// Cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Server-provided debug info carried in a predict-edits response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  73
/// Per-project state kept by [`Zeta`].
struct ZetaProject {
    /// Syntax index used when gathering prediction context.
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; holds at most `MAX_EVENT_COUNT` entries.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by buffer entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}

/// A buffer whose edits are tracked by [`Zeta`].
struct RegisteredBuffer {
    /// Snapshot as of the last recorded change (or of registration).
    snapshot: BufferSnapshot,
    /// Edit subscription and release observer for the buffer.
    _subscriptions: [gpui::Subscription; 2],
}
  84
/// An entry in a project's change history.
#[derive(Clone)]
pub enum Event {
    /// A buffer went from `old_snapshot` to `new_snapshot`. Rapid consecutive edits to the same
    /// buffer are merged into one event.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the (most recently merged) change was recorded.
        timestamp: Instant,
    },
}
  93
  94impl Zeta {
  95    pub fn global(
  96        client: &Arc<Client>,
  97        user_store: &Entity<UserStore>,
  98        cx: &mut App,
  99    ) -> Entity<Self> {
 100        cx.try_global::<ZetaGlobal>()
 101            .map(|global| global.0.clone())
 102            .unwrap_or_else(|| {
 103                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 104                cx.set_global(ZetaGlobal(zeta.clone()));
 105                zeta
 106            })
 107    }
 108
 109    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 110        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 111
 112        Self {
 113            projects: HashMap::new(),
 114            client,
 115            user_store,
 116            excerpt_options: DEFAULT_EXCERPT_OPTIONS,
 117            llm_token: LlmApiToken::default(),
 118            _llm_token_subscription: cx.subscribe(
 119                &refresh_llm_token_listener,
 120                |this, _listener, _event, cx| {
 121                    let client = this.client.clone();
 122                    let llm_token = this.llm_token.clone();
 123                    cx.spawn(async move |_this, _cx| {
 124                        llm_token.refresh(&client).await?;
 125                        anyhow::Ok(())
 126                    })
 127                    .detach_and_log_err(cx);
 128                },
 129            ),
 130            update_required: false,
 131            debug_tx: None,
 132        }
 133    }
 134
 135    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 136        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 137        self.debug_tx = Some(debug_watch_tx);
 138        debug_watch_rx
 139    }
 140
 141    pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
 142        &self.excerpt_options
 143    }
 144
 145    pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
 146        self.excerpt_options = options;
 147    }
 148
 149    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 150        self.user_store.read(cx).edit_prediction_usage()
 151    }
 152
 153    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 154        self.get_or_init_zeta_project(project, cx);
 155    }
 156
 157    pub fn register_buffer(
 158        &mut self,
 159        buffer: &Entity<Buffer>,
 160        project: &Entity<Project>,
 161        cx: &mut Context<Self>,
 162    ) {
 163        let zeta_project = self.get_or_init_zeta_project(project, cx);
 164        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 165    }
 166
 167    fn get_or_init_zeta_project(
 168        &mut self,
 169        project: &Entity<Project>,
 170        cx: &mut App,
 171    ) -> &mut ZetaProject {
 172        self.projects
 173            .entry(project.entity_id())
 174            .or_insert_with(|| ZetaProject {
 175                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 176                events: VecDeque::new(),
 177                registered_buffers: HashMap::new(),
 178            })
 179    }
 180
 181    fn register_buffer_impl<'a>(
 182        zeta_project: &'a mut ZetaProject,
 183        buffer: &Entity<Buffer>,
 184        project: &Entity<Project>,
 185        cx: &mut Context<Self>,
 186    ) -> &'a mut RegisteredBuffer {
 187        let buffer_id = buffer.entity_id();
 188        match zeta_project.registered_buffers.entry(buffer_id) {
 189            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 190            hash_map::Entry::Vacant(entry) => {
 191                let snapshot = buffer.read(cx).snapshot();
 192                let project_entity_id = project.entity_id();
 193                entry.insert(RegisteredBuffer {
 194                    snapshot,
 195                    _subscriptions: [
 196                        cx.subscribe(buffer, {
 197                            let project = project.downgrade();
 198                            move |this, buffer, event, cx| {
 199                                if let language::BufferEvent::Edited = event
 200                                    && let Some(project) = project.upgrade()
 201                                {
 202                                    this.report_changes_for_buffer(&buffer, &project, cx);
 203                                }
 204                            }
 205                        }),
 206                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 207                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 208                            else {
 209                                return;
 210                            };
 211                            zeta_project.registered_buffers.remove(&buffer_id);
 212                        }),
 213                    ],
 214                })
 215            }
 216        }
 217    }
 218
 219    fn report_changes_for_buffer(
 220        &mut self,
 221        buffer: &Entity<Buffer>,
 222        project: &Entity<Project>,
 223        cx: &mut Context<Self>,
 224    ) -> BufferSnapshot {
 225        let zeta_project = self.get_or_init_zeta_project(project, cx);
 226        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 227
 228        let new_snapshot = buffer.read(cx).snapshot();
 229        if new_snapshot.version != registered_buffer.snapshot.version {
 230            let old_snapshot =
 231                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 232            Self::push_event(
 233                zeta_project,
 234                Event::BufferChange {
 235                    old_snapshot,
 236                    new_snapshot: new_snapshot.clone(),
 237                    timestamp: Instant::now(),
 238                },
 239            );
 240        }
 241
 242        new_snapshot
 243    }
 244
 245    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 246        let events = &mut zeta_project.events;
 247
 248        if let Some(Event::BufferChange {
 249            new_snapshot: last_new_snapshot,
 250            timestamp: last_timestamp,
 251            ..
 252        }) = events.back_mut()
 253        {
 254            // Coalesce edits for the same buffer when they happen one after the other.
 255            let Event::BufferChange {
 256                old_snapshot,
 257                new_snapshot,
 258                timestamp,
 259            } = &event;
 260
 261            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 262                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 263                && old_snapshot.version == last_new_snapshot.version
 264            {
 265                *last_new_snapshot = new_snapshot.clone();
 266                *last_timestamp = *timestamp;
 267                return;
 268            }
 269        }
 270
 271        if events.len() >= MAX_EVENT_COUNT {
 272            // These are halved instead of popping to improve prompt caching.
 273            events.drain(..MAX_EVENT_COUNT / 2);
 274        }
 275
 276        events.push_back(event);
 277    }
 278
 279    pub fn request_prediction(
 280        &mut self,
 281        project: &Entity<Project>,
 282        buffer: &Entity<Buffer>,
 283        position: language::Anchor,
 284        cx: &mut Context<Self>,
 285    ) -> Task<Result<Option<EditPrediction>>> {
 286        let project_state = self.projects.get(&project.entity_id());
 287
 288        let index_state = project_state.map(|state| {
 289            state
 290                .syntax_index
 291                .read_with(cx, |index, _cx| index.state().clone())
 292        });
 293        let excerpt_options = self.excerpt_options.clone();
 294        let snapshot = buffer.read(cx).snapshot();
 295        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 296            return Task::ready(Err(anyhow!("No file path for excerpt")));
 297        };
 298        let client = self.client.clone();
 299        let llm_token = self.llm_token.clone();
 300        let app_version = AppVersion::global(cx);
 301        let worktree_snapshots = project
 302            .read(cx)
 303            .worktrees(cx)
 304            .map(|worktree| worktree.read(cx).snapshot())
 305            .collect::<Vec<_>>();
 306        let debug_tx = self.debug_tx.clone();
 307
 308        let request_task = cx.background_spawn({
 309            let snapshot = snapshot.clone();
 310            let buffer = buffer.clone();
 311            async move {
 312                let index_state = if let Some(index_state) = index_state {
 313                    Some(index_state.lock_owned().await)
 314                } else {
 315                    None
 316                };
 317
 318                let cursor_point = position.to_point(&snapshot);
 319
 320                let before_retrieval = chrono::Utc::now();
 321
 322                let Some(context) = EditPredictionContext::gather_context(
 323                    cursor_point,
 324                    &snapshot,
 325                    &excerpt_options,
 326                    index_state.as_deref(),
 327                ) else {
 328                    return Ok(None);
 329                };
 330
 331                let debug_context = if let Some(debug_tx) = debug_tx {
 332                    Some((debug_tx, context.clone()))
 333                } else {
 334                    None
 335                };
 336
 337                let request = make_cloud_request(
 338                    excerpt_path.clone(),
 339                    context,
 340                    // TODO pass everything
 341                    Vec::new(),
 342                    false,
 343                    Vec::new(),
 344                    None,
 345                    debug_context.is_some(),
 346                    &worktree_snapshots,
 347                    index_state.as_deref(),
 348                );
 349
 350                let retrieval_time = chrono::Utc::now() - before_retrieval;
 351                let response = Self::perform_request(client, llm_token, app_version, request).await;
 352
 353                if let Some((debug_tx, context)) = debug_context {
 354                    debug_tx
 355                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 356                            |response| {
 357                                let Some(request) =
 358                                    some_or_debug_panic(response.0.debug_info.clone())
 359                                else {
 360                                    return Err("Missing debug info".to_string());
 361                                };
 362                                Ok(PredictionDebugInfo {
 363                                    context,
 364                                    request,
 365                                    retrieval_time,
 366                                    buffer: buffer.downgrade(),
 367                                    position,
 368                                })
 369                            },
 370                        ))
 371                        .ok();
 372                }
 373
 374                anyhow::Ok(Some(response?))
 375            }
 376        });
 377
 378        let buffer = buffer.clone();
 379
 380        cx.spawn(async move |this, cx| {
 381            match request_task.await {
 382                Ok(Some((response, usage))) => {
 383                    log::debug!("predicted edits: {:?}", &response.edits);
 384
 385                    if let Some(usage) = usage {
 386                        this.update(cx, |this, cx| {
 387                            this.user_store.update(cx, |user_store, cx| {
 388                                user_store.update_edit_prediction_usage(usage, cx);
 389                            });
 390                        })
 391                        .ok();
 392                    }
 393
 394                    // TODO telemetry: duration, etc
 395
 396                    // TODO produce smaller edits by diffing against snapshot first
 397                    //
 398                    // Cloud returns entire snippets/excerpts ranges as they were included
 399                    // in the request, but we should display smaller edits to the user.
 400                    //
 401                    // We can do this by computing a diff of each one against the snapshot.
 402                    // Similar to zeta::Zeta::compute_edits, but per edit.
 403                    let edits = response
 404                        .edits
 405                        .into_iter()
 406                        .map(|edit| {
 407                            // TODO edits to different files
 408                            (
 409                                snapshot.anchor_before(edit.range.start)
 410                                    ..snapshot.anchor_before(edit.range.end),
 411                                edit.content,
 412                            )
 413                        })
 414                        .collect::<Vec<_>>()
 415                        .into();
 416
 417                    let Some((edits, snapshot, edit_preview_task)) =
 418                        buffer.read_with(cx, |buffer, cx| {
 419                            let new_snapshot = buffer.snapshot();
 420                            let edits: Arc<[_]> =
 421                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 422                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 423                        })?
 424                    else {
 425                        return Ok(None);
 426                    };
 427
 428                    Ok(Some(EditPrediction {
 429                        id: EditPredictionId(response.request_id),
 430                        edits,
 431                        snapshot,
 432                        edit_preview: edit_preview_task.await,
 433                    }))
 434                }
 435                Ok(None) => Ok(None),
 436                Err(err) => {
 437                    if err.is::<ZedUpdateRequiredError>() {
 438                        cx.update(|cx| {
 439                            this.update(cx, |this, _cx| {
 440                                this.update_required = true;
 441                            })
 442                            .ok();
 443
 444                            let error_message: SharedString = err.to_string().into();
 445                            show_app_notification(
 446                                NotificationId::unique::<ZedUpdateRequiredError>(),
 447                                cx,
 448                                move |cx| {
 449                                    cx.new(|cx| {
 450                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 451                                            .with_link_button(
 452                                                "Update Zed",
 453                                                "https://zed.dev/releases",
 454                                            )
 455                                    })
 456                                },
 457                            );
 458                        })
 459                        .ok();
 460                    }
 461
 462                    Err(err)
 463                }
 464            }
 465        })
 466    }
 467
 468    async fn perform_request(
 469        client: Arc<Client>,
 470        llm_token: LlmApiToken,
 471        app_version: SemanticVersion,
 472        request: predict_edits_v3::PredictEditsRequest,
 473    ) -> Result<(
 474        predict_edits_v3::PredictEditsResponse,
 475        Option<EditPredictionUsage>,
 476    )> {
 477        let http_client = client.http_client();
 478        let mut token = llm_token.acquire(&client).await?;
 479        let mut did_retry = false;
 480
 481        loop {
 482            let request_builder = http_client::Request::builder().method(Method::POST);
 483            let request_builder =
 484                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 485                    request_builder.uri(predict_edits_url)
 486                } else {
 487                    request_builder.uri(
 488                        http_client
 489                            .build_zed_llm_url("/predict_edits/v3", &[])?
 490                            .as_ref(),
 491                    )
 492                };
 493            let request = request_builder
 494                .header("Content-Type", "application/json")
 495                .header("Authorization", format!("Bearer {}", token))
 496                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 497                .body(serde_json::to_string(&request)?.into())?;
 498
 499            let mut response = http_client.send(request).await?;
 500
 501            if let Some(minimum_required_version) = response
 502                .headers()
 503                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 504                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 505            {
 506                anyhow::ensure!(
 507                    app_version >= minimum_required_version,
 508                    ZedUpdateRequiredError {
 509                        minimum_version: minimum_required_version
 510                    }
 511                );
 512            }
 513
 514            if response.status().is_success() {
 515                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 516
 517                let mut body = Vec::new();
 518                response.body_mut().read_to_end(&mut body).await?;
 519                return Ok((serde_json::from_slice(&body)?, usage));
 520            } else if !did_retry
 521                && response
 522                    .headers()
 523                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 524                    .is_some()
 525            {
 526                did_retry = true;
 527                token = llm_token.refresh(&client).await?;
 528            } else {
 529                let mut body = String::new();
 530                response.body_mut().read_to_string(&mut body).await?;
 531                anyhow::bail!(
 532                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 533                    response.status(),
 534                    body
 535                );
 536            }
 537        }
 538    }
 539
 540    // TODO: Dedupe with similar code in request_prediction?
 541    pub fn cloud_request_for_zeta_cli(
 542        &mut self,
 543        project: &Entity<Project>,
 544        buffer: &Entity<Buffer>,
 545        position: language::Anchor,
 546        cx: &mut Context<Self>,
 547    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 548        let project_state = self.projects.get(&project.entity_id());
 549
 550        let index_state = project_state.map(|state| {
 551            state
 552                .syntax_index
 553                .read_with(cx, |index, _cx| index.state().clone())
 554        });
 555        let excerpt_options = self.excerpt_options.clone();
 556        let snapshot = buffer.read(cx).snapshot();
 557        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 558            return Task::ready(Err(anyhow!("No file path for excerpt")));
 559        };
 560        let worktree_snapshots = project
 561            .read(cx)
 562            .worktrees(cx)
 563            .map(|worktree| worktree.read(cx).snapshot())
 564            .collect::<Vec<_>>();
 565
 566        cx.background_spawn(async move {
 567            let index_state = if let Some(index_state) = index_state {
 568                Some(index_state.lock_owned().await)
 569            } else {
 570                None
 571            };
 572
 573            let cursor_point = position.to_point(&snapshot);
 574
 575            let debug_info = true;
 576            EditPredictionContext::gather_context(
 577                cursor_point,
 578                &snapshot,
 579                &excerpt_options,
 580                index_state.as_deref(),
 581            )
 582            .context("Failed to select excerpt")
 583            .map(|context| {
 584                make_cloud_request(
 585                    excerpt_path.clone(),
 586                    context,
 587                    // TODO pass everything
 588                    Vec::new(),
 589                    false,
 590                    Vec::new(),
 591                    None,
 592                    debug_info,
 593                    &worktree_snapshots,
 594                    index_state.as_deref(),
 595                )
 596            })
 597        })
 598    }
 599}
 600
/// Returned when the predict-edits server's minimum-required-version header names a version
/// newer than the running Zed.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// Oldest Zed version the server accepts.
    minimum_version: SemanticVersion,
}
 608
/// [`EditPredictionProvider`] backed by the global [`Zeta`] instance.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// Prediction currently offered to the user, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id assigned to the next request started by `refresh`.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests, oldest first; at most two are kept.
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last prediction request was issued; used for throttling.
    last_request_timestamp: Instant,
}
 616
 617impl ZetaEditPredictionProvider {
 618    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 619
 620    pub fn new(
 621        project: Option<&Entity<Project>>,
 622        client: &Arc<Client>,
 623        user_store: &Entity<UserStore>,
 624        cx: &mut App,
 625    ) -> Self {
 626        let zeta = Zeta::global(client, user_store, cx);
 627        if let Some(project) = project {
 628            zeta.update(cx, |zeta, cx| {
 629                zeta.register_project(project, cx);
 630            });
 631        }
 632
 633        Self {
 634            zeta,
 635            current_prediction: None,
 636            next_pending_prediction_id: 0,
 637            pending_predictions: ArrayVec::new(),
 638            last_request_timestamp: Instant::now(),
 639        }
 640    }
 641}
 642
/// A prediction together with the buffer it applies to.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction was requested for.
    buffer_id: EntityId,
    prediction: EditPrediction,
}
 648
 649impl CurrentEditPrediction {
 650    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 651        if self.buffer_id != old_prediction.buffer_id {
 652            return true;
 653        }
 654
 655        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 656            return true;
 657        };
 658        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 659            return false;
 660        };
 661
 662        if old_edits.len() == 1 && new_edits.len() == 1 {
 663            let (old_range, old_text) = &old_edits[0];
 664            let (new_range, new_text) = &new_edits[0];
 665            new_range == old_range && new_text.starts_with(old_text)
 666        } else {
 667            true
 668        }
 669    }
 670}
 671
/// Identifier of a prediction, taken from the `request_id` of the server response.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
 674
 675impl From<EditPredictionId> for gpui::ElementId {
 676    fn from(value: EditPredictionId) -> Self {
 677        gpui::ElementId::Uuid(value.0)
 678    }
 679}
 680
 681impl std::fmt::Display for EditPredictionId {
 682    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 683        write!(f, "{}", self.0)
 684    }
 685}
 686
/// A set of predicted edits for one buffer.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// Edits anchored in `snapshot`: each range is replaced by the accompanying text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// Buffer snapshot the edits were interpolated onto when the prediction was created.
    snapshot: BufferSnapshot,
    /// Preview of the edits, produced by the buffer's `preview_edits`.
    edit_preview: EditPreview,
}
 694
 695impl EditPrediction {
 696    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 697        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 698    }
 699}
 700
/// An in-flight prediction request started by `refresh`.
struct PendingPrediction {
    /// Id used to recognize this request when its task completes.
    id: usize,
    /// The request task; it is held here for as long as the request is pending.
    _task: Task<()>,
}
 705
 706impl EditPredictionProvider for ZetaEditPredictionProvider {
 707    fn name() -> &'static str {
 708        "zed-predict2"
 709    }
 710
 711    fn display_name() -> &'static str {
 712        "Zed's Edit Predictions 2"
 713    }
 714
 715    fn show_completions_in_menu() -> bool {
 716        true
 717    }
 718
 719    fn show_tab_accept_marker() -> bool {
 720        true
 721    }
 722
 723    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 724        // TODO [zeta2]
 725        DataCollectionState::Unsupported
 726    }
 727
 728    fn toggle_data_collection(&mut self, _cx: &mut App) {
 729        // TODO [zeta2]
 730    }
 731
 732    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 733        self.zeta.read(cx).usage(cx)
 734    }
 735
 736    fn is_enabled(
 737        &self,
 738        _buffer: &Entity<language::Buffer>,
 739        _cursor_position: language::Anchor,
 740        _cx: &App,
 741    ) -> bool {
 742        true
 743    }
 744
 745    fn is_refreshing(&self) -> bool {
 746        !self.pending_predictions.is_empty()
 747    }
 748
 749    fn refresh(
 750        &mut self,
 751        project: Option<Entity<project::Project>>,
 752        buffer: Entity<language::Buffer>,
 753        cursor_position: language::Anchor,
 754        _debounce: bool,
 755        cx: &mut Context<Self>,
 756    ) {
 757        let Some(project) = project else {
 758            return;
 759        };
 760
 761        if self
 762            .zeta
 763            .read(cx)
 764            .user_store
 765            .read_with(cx, |user_store, _cx| {
 766                user_store.account_too_young() || user_store.has_overdue_invoices()
 767            })
 768        {
 769            return;
 770        }
 771
 772        if let Some(current_prediction) = self.current_prediction.as_ref() {
 773            let snapshot = buffer.read(cx).snapshot();
 774            if current_prediction
 775                .prediction
 776                .interpolate(&snapshot)
 777                .is_some()
 778            {
 779                return;
 780            }
 781        }
 782
 783        let pending_prediction_id = self.next_pending_prediction_id;
 784        self.next_pending_prediction_id += 1;
 785        let last_request_timestamp = self.last_request_timestamp;
 786
 787        let task = cx.spawn(async move |this, cx| {
 788            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 789                .checked_duration_since(Instant::now())
 790            {
 791                cx.background_executor().timer(timeout).await;
 792            }
 793
 794            let prediction_request = this.update(cx, |this, cx| {
 795                this.last_request_timestamp = Instant::now();
 796                this.zeta.update(cx, |zeta, cx| {
 797                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 798                })
 799            });
 800
 801            let prediction = match prediction_request {
 802                Ok(prediction_request) => {
 803                    let prediction_request = prediction_request.await;
 804                    prediction_request.map(|c| {
 805                        c.map(|prediction| CurrentEditPrediction {
 806                            buffer_id: buffer.entity_id(),
 807                            prediction,
 808                        })
 809                    })
 810                }
 811                Err(error) => Err(error),
 812            };
 813
 814            this.update(cx, |this, cx| {
 815                if this.pending_predictions[0].id == pending_prediction_id {
 816                    this.pending_predictions.remove(0);
 817                } else {
 818                    this.pending_predictions.clear();
 819                }
 820
 821                let Some(new_prediction) = prediction
 822                    .context("edit prediction failed")
 823                    .log_err()
 824                    .flatten()
 825                else {
 826                    cx.notify();
 827                    return;
 828                };
 829
 830                if let Some(old_prediction) = this.current_prediction.as_ref() {
 831                    let snapshot = buffer.read(cx).snapshot();
 832                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 833                        this.current_prediction = Some(new_prediction);
 834                    }
 835                } else {
 836                    this.current_prediction = Some(new_prediction);
 837                }
 838
 839                cx.notify();
 840            })
 841            .ok();
 842        });
 843
 844        // We always maintain at most two pending predictions. When we already
 845        // have two, we replace the newest one.
 846        if self.pending_predictions.len() <= 1 {
 847            self.pending_predictions.push(PendingPrediction {
 848                id: pending_prediction_id,
 849                _task: task,
 850            });
 851        } else if self.pending_predictions.len() == 2 {
 852            self.pending_predictions.pop();
 853            self.pending_predictions.push(PendingPrediction {
 854                id: pending_prediction_id,
 855                _task: task,
 856            });
 857        }
 858
 859        cx.notify();
 860    }
 861
 862    fn cycle(
 863        &mut self,
 864        _buffer: Entity<language::Buffer>,
 865        _cursor_position: language::Anchor,
 866        _direction: Direction,
 867        _cx: &mut Context<Self>,
 868    ) {
 869    }
 870
 871    fn accept(&mut self, _cx: &mut Context<Self>) {
 872        // TODO [zeta2] report accept
 873        self.current_prediction.take();
 874        self.pending_predictions.clear();
 875    }
 876
 877    fn discard(&mut self, _cx: &mut Context<Self>) {
 878        self.pending_predictions.clear();
 879        self.current_prediction.take();
 880    }
 881
 882    fn suggest(
 883        &mut self,
 884        buffer: &Entity<language::Buffer>,
 885        cursor_position: language::Anchor,
 886        cx: &mut Context<Self>,
 887    ) -> Option<edit_prediction::EditPrediction> {
 888        let CurrentEditPrediction {
 889            buffer_id,
 890            prediction,
 891            ..
 892        } = self.current_prediction.as_mut()?;
 893
 894        // Invalidate previous prediction if it was generated for a different buffer.
 895        if *buffer_id != buffer.entity_id() {
 896            self.current_prediction.take();
 897            return None;
 898        }
 899
 900        let buffer = buffer.read(cx);
 901        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 902            self.current_prediction.take();
 903            return None;
 904        };
 905
 906        let cursor_row = cursor_position.to_point(buffer).row;
 907        let (closest_edit_ix, (closest_edit_range, _)) =
 908            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 909                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 910                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 911                cmp::min(distance_from_start, distance_from_end)
 912            })?;
 913
 914        let mut edit_start_ix = closest_edit_ix;
 915        for (range, _) in edits[..edit_start_ix].iter().rev() {
 916            let distance_from_closest_edit =
 917                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 918            if distance_from_closest_edit <= 1 {
 919                edit_start_ix -= 1;
 920            } else {
 921                break;
 922            }
 923        }
 924
 925        let mut edit_end_ix = closest_edit_ix + 1;
 926        for (range, _) in &edits[edit_end_ix..] {
 927            let distance_from_closest_edit =
 928                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 929            if distance_from_closest_edit <= 1 {
 930                edit_end_ix += 1;
 931            } else {
 932                break;
 933            }
 934        }
 935
 936        Some(edit_prediction::EditPrediction {
 937            id: Some(prediction.id.to_string().into()),
 938            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 939            edit_preview: Some(prediction.edit_preview.clone()),
 940        })
 941    }
 942}
 943
 944fn make_cloud_request(
 945    excerpt_path: PathBuf,
 946    context: EditPredictionContext,
 947    events: Vec<predict_edits_v3::Event>,
 948    can_collect_data: bool,
 949    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 950    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 951    debug_info: bool,
 952    worktrees: &Vec<worktree::Snapshot>,
 953    index_state: Option<&SyntaxIndexState>,
 954) -> predict_edits_v3::PredictEditsRequest {
 955    let mut signatures = Vec::new();
 956    let mut declaration_to_signature_index = HashMap::default();
 957    let mut referenced_declarations = Vec::new();
 958
 959    for snippet in context.snippets {
 960        let project_entry_id = snippet.declaration.project_entry_id();
 961        let Some(path) = worktrees.iter().find_map(|worktree| {
 962            worktree.entry_for_id(project_entry_id).map(|entry| {
 963                let mut full_path = PathBuf::new();
 964                full_path.push(worktree.root_name());
 965                full_path.push(&entry.path);
 966                full_path
 967            })
 968        }) else {
 969            continue;
 970        };
 971
 972        let parent_index = index_state.and_then(|index_state| {
 973            snippet.declaration.parent().and_then(|parent| {
 974                add_signature(
 975                    parent,
 976                    &mut declaration_to_signature_index,
 977                    &mut signatures,
 978                    index_state,
 979                )
 980            })
 981        });
 982
 983        let (text, text_is_truncated) = snippet.declaration.item_text();
 984        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
 985            path,
 986            text: text.into(),
 987            range: snippet.declaration.item_range(),
 988            text_is_truncated,
 989            signature_range: snippet.declaration.signature_range_in_item_text(),
 990            parent_index,
 991            score_components: snippet.score_components,
 992            signature_score: snippet.scores.signature,
 993            declaration_score: snippet.scores.declaration,
 994        });
 995    }
 996
 997    let excerpt_parent = index_state.and_then(|index_state| {
 998        context
 999            .excerpt
1000            .parent_declarations
1001            .last()
1002            .and_then(|(parent, _)| {
1003                add_signature(
1004                    *parent,
1005                    &mut declaration_to_signature_index,
1006                    &mut signatures,
1007                    index_state,
1008                )
1009            })
1010    });
1011
1012    predict_edits_v3::PredictEditsRequest {
1013        excerpt_path,
1014        excerpt: context.excerpt_text.body,
1015        excerpt_range: context.excerpt.range,
1016        cursor_offset: context.cursor_offset_in_excerpt,
1017        referenced_declarations,
1018        signatures,
1019        excerpt_parent,
1020        events,
1021        can_collect_data,
1022        diagnostic_groups,
1023        git_info,
1024        debug_info,
1025    }
1026}
1027
1028fn add_signature(
1029    declaration_id: DeclarationId,
1030    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1031    signatures: &mut Vec<Signature>,
1032    index: &SyntaxIndexState,
1033) -> Option<usize> {
1034    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1035        return Some(*signature_index);
1036    }
1037    let Some(parent_declaration) = index.declaration(declaration_id) else {
1038        log::error!("bug: missing parent declaration");
1039        return None;
1040    };
1041    let parent_index = parent_declaration.parent().and_then(|parent| {
1042        add_signature(parent, declaration_to_signature_index, signatures, index)
1043    });
1044    let (text, text_is_truncated) = parent_declaration.signature_text();
1045    let signature_index = signatures.len();
1046    signatures.push(Signature {
1047        text: text.into(),
1048        text_is_truncated,
1049        parent_index,
1050        range: parent_declaration.signature_range(),
1051    });
1052    declaration_to_signature_index.insert(declaration_id, signature_index);
1053    Some(signature_index)
1054}
1055
1056fn interpolate(
1057    old_snapshot: &BufferSnapshot,
1058    new_snapshot: &BufferSnapshot,
1059    current_edits: Arc<[(Range<Anchor>, String)]>,
1060) -> Option<Vec<(Range<Anchor>, String)>> {
1061    let mut edits = Vec::new();
1062
1063    let mut model_edits = current_edits.iter().peekable();
1064    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1065        while let Some((model_old_range, _)) = model_edits.peek() {
1066            let model_old_range = model_old_range.to_offset(old_snapshot);
1067            if model_old_range.end < user_edit.old.start {
1068                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1069                edits.push((model_old_range.clone(), model_new_text.clone()));
1070            } else {
1071                break;
1072            }
1073        }
1074
1075        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1076            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1077            if user_edit.old == model_old_offset_range {
1078                let user_new_text = new_snapshot
1079                    .text_for_range(user_edit.new.clone())
1080                    .collect::<String>();
1081
1082                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1083                    if !model_suffix.is_empty() {
1084                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1085                        edits.push((anchor..anchor, model_suffix.to_string()));
1086                    }
1087
1088                    model_edits.next();
1089                    continue;
1090                }
1091            }
1092        }
1093
1094        return None;
1095    }
1096
1097    edits.extend(model_edits.cloned());
1098
1099    if edits.is_empty() { None } else { Some(edits) }
1100}
1101
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Exercises `EditPrediction::interpolate` while the user edits a buffer in
    /// the predicted places: deleting, typing prefixes, undoing, applying the
    /// model's edit manually, and finally diverging from the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back as-is.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // Deleting exactly the replaced range leaves the model text as an insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // Typing a prefix of the replacement keeps only the untyped suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // Typing the whole replacement leaves only the other edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            // Deleting typed text brings the corresponding suffix back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // Performing the model's deletion by hand removes it from the prediction.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` onto `buffer`'s current snapshot and converts
    /// the resulting anchor ranges to offsets so assertions can compare plain
    /// data.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let edits = prediction.interpolate(&buffer.read(cx).snapshot())?;
        Some(from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based edits against `buffer`'s
    /// current contents.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let text_buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = text_buffer.anchor_after(range.start);
            let end = text_buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back to offsets in `buffer`'s current contents.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let text_buffer = buffer.read(cx);
        let to_offsets = |range: &Range<Anchor>| {
            range.start.to_offset(text_buffer)..range.end.to_offset(text_buffer)
        };
        editor_edits
            .iter()
            .map(|(range, text)| (to_offsets(range), text.clone()))
            .collect()
    }
}