// zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  11use edit_prediction_context::{
  12    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  13    SyntaxIndexState,
  14};
  15use futures::AsyncReadExt as _;
  16use futures::channel::mpsc;
  17use gpui::http_client::Method;
  18use gpui::{
  19    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  20    http_client, prelude::*,
  21};
  22use language::{
  23    Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
  24    text_diff,
  25};
  26use language::{BufferSnapshot, EditPreview};
  27use language_model::{LlmApiToken, RefreshLlmTokenListener};
  28use project::Project;
  29use release_channel::AppVersion;
  30use std::borrow::Cow;
  31use std::cmp;
  32use std::collections::{HashMap, VecDeque, hash_map};
  33use std::path::PathBuf;
  34use std::str::FromStr as _;
  35use std::time::{Duration, Instant};
  36use std::{ops::Range, sync::Arc};
  37use thiserror::Error;
  38use util::{ResultExt as _, some_or_debug_panic};
  39use uuid::Uuid;
  40use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  41
  42const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  43
  44/// Maximum number of events to track.
  45const MAX_EVENT_COUNT: usize = 16;
  46
  47pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  48    max_bytes: 512,
  49    min_bytes: 128,
  50    target_before_cursor_over_total_bytes: 0.5,
  51};
  52
  53pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  54    excerpt: DEFAULT_EXCERPT_OPTIONS,
  55    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  56    max_diagnostic_bytes: 2048,
  57};
  58
  59#[derive(Clone)]
  60struct ZetaGlobal(Entity<Zeta>);
  61
  62impl Global for ZetaGlobal {}
  63
  64pub struct Zeta {
  65    client: Arc<Client>,
  66    user_store: Entity<UserStore>,
  67    llm_token: LlmApiToken,
  68    _llm_token_subscription: Subscription,
  69    projects: HashMap<EntityId, ZetaProject>,
  70    options: ZetaOptions,
  71    update_required: bool,
  72    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
  73}
  74
  75#[derive(Debug, Clone, PartialEq)]
  76pub struct ZetaOptions {
  77    pub excerpt: EditPredictionExcerptOptions,
  78    pub max_prompt_bytes: usize,
  79    pub max_diagnostic_bytes: usize,
  80}
  81
  82pub struct PredictionDebugInfo {
  83    pub context: EditPredictionContext,
  84    pub retrieval_time: TimeDelta,
  85    pub request: RequestDebugInfo,
  86    pub buffer: WeakEntity<Buffer>,
  87    pub position: language::Anchor,
  88}
  89
  90pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  91
  92struct ZetaProject {
  93    syntax_index: Entity<SyntaxIndex>,
  94    events: VecDeque<Event>,
  95    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  96}
  97
  98struct RegisteredBuffer {
  99    snapshot: BufferSnapshot,
 100    _subscriptions: [gpui::Subscription; 2],
 101}
 102
 103#[derive(Clone)]
 104pub enum Event {
 105    BufferChange {
 106        old_snapshot: BufferSnapshot,
 107        new_snapshot: BufferSnapshot,
 108        timestamp: Instant,
 109    },
 110}
 111
 112impl Zeta {
 113    pub fn global(
 114        client: &Arc<Client>,
 115        user_store: &Entity<UserStore>,
 116        cx: &mut App,
 117    ) -> Entity<Self> {
 118        cx.try_global::<ZetaGlobal>()
 119            .map(|global| global.0.clone())
 120            .unwrap_or_else(|| {
 121                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 122                cx.set_global(ZetaGlobal(zeta.clone()));
 123                zeta
 124            })
 125    }
 126
 127    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 128        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 129
 130        Self {
 131            projects: HashMap::new(),
 132            client,
 133            user_store,
 134            options: DEFAULT_OPTIONS,
 135            llm_token: LlmApiToken::default(),
 136            _llm_token_subscription: cx.subscribe(
 137                &refresh_llm_token_listener,
 138                |this, _listener, _event, cx| {
 139                    let client = this.client.clone();
 140                    let llm_token = this.llm_token.clone();
 141                    cx.spawn(async move |_this, _cx| {
 142                        llm_token.refresh(&client).await?;
 143                        anyhow::Ok(())
 144                    })
 145                    .detach_and_log_err(cx);
 146                },
 147            ),
 148            update_required: false,
 149            debug_tx: None,
 150        }
 151    }
 152
 153    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 154        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 155        self.debug_tx = Some(debug_watch_tx);
 156        debug_watch_rx
 157    }
 158
 159    pub fn options(&self) -> &ZetaOptions {
 160        &self.options
 161    }
 162
 163    pub fn set_options(&mut self, options: ZetaOptions) {
 164        self.options = options;
 165    }
 166
 167    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 168        self.user_store.read(cx).edit_prediction_usage()
 169    }
 170
 171    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 172        self.get_or_init_zeta_project(project, cx);
 173    }
 174
 175    pub fn register_buffer(
 176        &mut self,
 177        buffer: &Entity<Buffer>,
 178        project: &Entity<Project>,
 179        cx: &mut Context<Self>,
 180    ) {
 181        let zeta_project = self.get_or_init_zeta_project(project, cx);
 182        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 183    }
 184
 185    fn get_or_init_zeta_project(
 186        &mut self,
 187        project: &Entity<Project>,
 188        cx: &mut App,
 189    ) -> &mut ZetaProject {
 190        self.projects
 191            .entry(project.entity_id())
 192            .or_insert_with(|| ZetaProject {
 193                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 194                events: VecDeque::new(),
 195                registered_buffers: HashMap::new(),
 196            })
 197    }
 198
 199    fn register_buffer_impl<'a>(
 200        zeta_project: &'a mut ZetaProject,
 201        buffer: &Entity<Buffer>,
 202        project: &Entity<Project>,
 203        cx: &mut Context<Self>,
 204    ) -> &'a mut RegisteredBuffer {
 205        let buffer_id = buffer.entity_id();
 206        match zeta_project.registered_buffers.entry(buffer_id) {
 207            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 208            hash_map::Entry::Vacant(entry) => {
 209                let snapshot = buffer.read(cx).snapshot();
 210                let project_entity_id = project.entity_id();
 211                entry.insert(RegisteredBuffer {
 212                    snapshot,
 213                    _subscriptions: [
 214                        cx.subscribe(buffer, {
 215                            let project = project.downgrade();
 216                            move |this, buffer, event, cx| {
 217                                if let language::BufferEvent::Edited = event
 218                                    && let Some(project) = project.upgrade()
 219                                {
 220                                    this.report_changes_for_buffer(&buffer, &project, cx);
 221                                }
 222                            }
 223                        }),
 224                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 225                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 226                            else {
 227                                return;
 228                            };
 229                            zeta_project.registered_buffers.remove(&buffer_id);
 230                        }),
 231                    ],
 232                })
 233            }
 234        }
 235    }
 236
 237    fn report_changes_for_buffer(
 238        &mut self,
 239        buffer: &Entity<Buffer>,
 240        project: &Entity<Project>,
 241        cx: &mut Context<Self>,
 242    ) -> BufferSnapshot {
 243        let zeta_project = self.get_or_init_zeta_project(project, cx);
 244        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 245
 246        let new_snapshot = buffer.read(cx).snapshot();
 247        if new_snapshot.version != registered_buffer.snapshot.version {
 248            let old_snapshot =
 249                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 250            Self::push_event(
 251                zeta_project,
 252                Event::BufferChange {
 253                    old_snapshot,
 254                    new_snapshot: new_snapshot.clone(),
 255                    timestamp: Instant::now(),
 256                },
 257            );
 258        }
 259
 260        new_snapshot
 261    }
 262
 263    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 264        let events = &mut zeta_project.events;
 265
 266        if let Some(Event::BufferChange {
 267            new_snapshot: last_new_snapshot,
 268            timestamp: last_timestamp,
 269            ..
 270        }) = events.back_mut()
 271        {
 272            // Coalesce edits for the same buffer when they happen one after the other.
 273            let Event::BufferChange {
 274                old_snapshot,
 275                new_snapshot,
 276                timestamp,
 277            } = &event;
 278
 279            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 280                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 281                && old_snapshot.version == last_new_snapshot.version
 282            {
 283                *last_new_snapshot = new_snapshot.clone();
 284                *last_timestamp = *timestamp;
 285                return;
 286            }
 287        }
 288
 289        if events.len() >= MAX_EVENT_COUNT {
 290            // These are halved instead of popping to improve prompt caching.
 291            events.drain(..MAX_EVENT_COUNT / 2);
 292        }
 293
 294        events.push_back(event);
 295    }
 296
 297    pub fn request_prediction(
 298        &mut self,
 299        project: &Entity<Project>,
 300        buffer: &Entity<Buffer>,
 301        position: language::Anchor,
 302        cx: &mut Context<Self>,
 303    ) -> Task<Result<Option<EditPrediction>>> {
 304        let project_state = self.projects.get(&project.entity_id());
 305
 306        let index_state = project_state.map(|state| {
 307            state
 308                .syntax_index
 309                .read_with(cx, |index, _cx| index.state().clone())
 310        });
 311        let options = self.options.clone();
 312        let snapshot = buffer.read(cx).snapshot();
 313        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 314            return Task::ready(Err(anyhow!("No file path for excerpt")));
 315        };
 316        let client = self.client.clone();
 317        let llm_token = self.llm_token.clone();
 318        let app_version = AppVersion::global(cx);
 319        let worktree_snapshots = project
 320            .read(cx)
 321            .worktrees(cx)
 322            .map(|worktree| worktree.read(cx).snapshot())
 323            .collect::<Vec<_>>();
 324        let debug_tx = self.debug_tx.clone();
 325
 326        let events = project_state
 327            .map(|state| {
 328                state
 329                    .events
 330                    .iter()
 331                    .map(|event| match event {
 332                        Event::BufferChange {
 333                            old_snapshot,
 334                            new_snapshot,
 335                            ..
 336                        } => {
 337                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
 338
 339                            let old_path = old_snapshot.file().and_then(|f| {
 340                                let old_path = f.path().as_ref();
 341                                if Some(old_path) != path.as_deref() {
 342                                    Some(old_path.to_path_buf())
 343                                } else {
 344                                    None
 345                                }
 346                            });
 347
 348                            predict_edits_v3::Event::BufferChange {
 349                                old_path,
 350                                path,
 351                                diff: language::unified_diff(
 352                                    &old_snapshot.text(),
 353                                    &new_snapshot.text(),
 354                                ),
 355                                //todo: Actually detect if this edit was predicted or not
 356                                predicted: false,
 357                            }
 358                        }
 359                    })
 360                    .collect::<Vec<_>>()
 361            })
 362            .unwrap_or_default();
 363
 364        let diagnostics = snapshot.diagnostic_sets().clone();
 365
 366        let request_task = cx.background_spawn({
 367            let snapshot = snapshot.clone();
 368            let buffer = buffer.clone();
 369            async move {
 370                let index_state = if let Some(index_state) = index_state {
 371                    Some(index_state.lock_owned().await)
 372                } else {
 373                    None
 374                };
 375
 376                let cursor_offset = position.to_offset(&snapshot);
 377                let cursor_point = cursor_offset.to_point(&snapshot);
 378
 379                let before_retrieval = chrono::Utc::now();
 380
 381                let Some(context) = EditPredictionContext::gather_context(
 382                    cursor_point,
 383                    &snapshot,
 384                    &options.excerpt,
 385                    index_state.as_deref(),
 386                ) else {
 387                    return Ok(None);
 388                };
 389
 390                let debug_context = if let Some(debug_tx) = debug_tx {
 391                    Some((debug_tx, context.clone()))
 392                } else {
 393                    None
 394                };
 395
 396                let (diagnostic_groups, diagnostic_groups_truncated) =
 397                    Self::gather_nearby_diagnostics(
 398                        cursor_offset,
 399                        &diagnostics,
 400                        &snapshot,
 401                        options.max_diagnostic_bytes,
 402                    );
 403
 404                let request = make_cloud_request(
 405                    excerpt_path.clone(),
 406                    context,
 407                    events,
 408                    // TODO data collection
 409                    false,
 410                    diagnostic_groups,
 411                    diagnostic_groups_truncated,
 412                    None,
 413                    debug_context.is_some(),
 414                    &worktree_snapshots,
 415                    index_state.as_deref(),
 416                    Some(options.max_prompt_bytes),
 417                );
 418
 419                let retrieval_time = chrono::Utc::now() - before_retrieval;
 420                let response = Self::perform_request(client, llm_token, app_version, request).await;
 421
 422                if let Some((debug_tx, context)) = debug_context {
 423                    debug_tx
 424                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 425                            |response| {
 426                                let Some(request) =
 427                                    some_or_debug_panic(response.0.debug_info.clone())
 428                                else {
 429                                    return Err("Missing debug info".to_string());
 430                                };
 431                                Ok(PredictionDebugInfo {
 432                                    context,
 433                                    request,
 434                                    retrieval_time,
 435                                    buffer: buffer.downgrade(),
 436                                    position,
 437                                })
 438                            },
 439                        ))
 440                        .ok();
 441                }
 442
 443                let (response, usage) = response?;
 444                let edits = Self::compute_edits(&response.edits, &snapshot);
 445
 446                anyhow::Ok(Some((response.request_id, edits, usage)))
 447            }
 448        });
 449
 450        let buffer = buffer.clone();
 451
 452        cx.spawn(async move |this, cx| {
 453            match request_task.await {
 454                Ok(Some((id, edits, usage))) => {
 455                    if let Some(usage) = usage {
 456                        this.update(cx, |this, cx| {
 457                            this.user_store.update(cx, |user_store, cx| {
 458                                user_store.update_edit_prediction_usage(usage, cx);
 459                            });
 460                        })
 461                        .ok();
 462                    }
 463
 464                    // TODO telemetry: duration, etc
 465                    let Some((edits, snapshot, edit_preview_task)) =
 466                        buffer.read_with(cx, |buffer, cx| {
 467                            let new_snapshot = buffer.snapshot();
 468                            let edits: Arc<[_]> =
 469                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 470                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 471                        })?
 472                    else {
 473                        return Ok(None);
 474                    };
 475
 476                    Ok(Some(EditPrediction {
 477                        id: EditPredictionId(id),
 478                        edits,
 479                        snapshot,
 480                        edit_preview: edit_preview_task.await,
 481                    }))
 482                }
 483                Ok(None) => Ok(None),
 484                Err(err) => {
 485                    if err.is::<ZedUpdateRequiredError>() {
 486                        cx.update(|cx| {
 487                            this.update(cx, |this, _cx| {
 488                                this.update_required = true;
 489                            })
 490                            .ok();
 491
 492                            let error_message: SharedString = err.to_string().into();
 493                            show_app_notification(
 494                                NotificationId::unique::<ZedUpdateRequiredError>(),
 495                                cx,
 496                                move |cx| {
 497                                    cx.new(|cx| {
 498                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 499                                            .with_link_button(
 500                                                "Update Zed",
 501                                                "https://zed.dev/releases",
 502                                            )
 503                                    })
 504                                },
 505                            );
 506                        })
 507                        .ok();
 508                    }
 509
 510                    Err(err)
 511                }
 512            }
 513        })
 514    }
 515
 516    async fn perform_request(
 517        client: Arc<Client>,
 518        llm_token: LlmApiToken,
 519        app_version: SemanticVersion,
 520        request: predict_edits_v3::PredictEditsRequest,
 521    ) -> Result<(
 522        predict_edits_v3::PredictEditsResponse,
 523        Option<EditPredictionUsage>,
 524    )> {
 525        let http_client = client.http_client();
 526        let mut token = llm_token.acquire(&client).await?;
 527        let mut did_retry = false;
 528
 529        loop {
 530            let request_builder = http_client::Request::builder().method(Method::POST);
 531            let request_builder =
 532                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 533                    request_builder.uri(predict_edits_url)
 534                } else {
 535                    request_builder.uri(
 536                        http_client
 537                            .build_zed_llm_url("/predict_edits/v3", &[])?
 538                            .as_ref(),
 539                    )
 540                };
 541            let request = request_builder
 542                .header("Content-Type", "application/json")
 543                .header("Authorization", format!("Bearer {}", token))
 544                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 545                .body(serde_json::to_string(&request)?.into())?;
 546
 547            let mut response = http_client.send(request).await?;
 548
 549            if let Some(minimum_required_version) = response
 550                .headers()
 551                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 552                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 553            {
 554                anyhow::ensure!(
 555                    app_version >= minimum_required_version,
 556                    ZedUpdateRequiredError {
 557                        minimum_version: minimum_required_version
 558                    }
 559                );
 560            }
 561
 562            if response.status().is_success() {
 563                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 564
 565                let mut body = Vec::new();
 566                response.body_mut().read_to_end(&mut body).await?;
 567                return Ok((serde_json::from_slice(&body)?, usage));
 568            } else if !did_retry
 569                && response
 570                    .headers()
 571                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 572                    .is_some()
 573            {
 574                did_retry = true;
 575                token = llm_token.refresh(&client).await?;
 576            } else {
 577                let mut body = String::new();
 578                response.body_mut().read_to_string(&mut body).await?;
 579                anyhow::bail!(
 580                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 581                    response.status(),
 582                    body
 583                );
 584            }
 585        }
 586    }
 587
 588    fn compute_edits(
 589        edits: &[predict_edits_v3::Edit],
 590        snapshot: &BufferSnapshot,
 591    ) -> Arc<[(Range<Anchor>, String)]> {
 592        edits
 593            .iter()
 594            .flat_map(|edit| {
 595                // TODO multi-file edits
 596                let old_text = snapshot.text_for_range(edit.range.clone());
 597
 598                Self::compute_excerpt_edits(
 599                    old_text.collect::<Cow<str>>(),
 600                    &edit.content,
 601                    edit.range.start,
 602                    &snapshot,
 603                )
 604            })
 605            .collect::<Vec<_>>()
 606            .into()
 607    }
 608
 609    fn compute_excerpt_edits(
 610        old_text: Cow<str>,
 611        new_text: &str,
 612        offset: usize,
 613        snapshot: &BufferSnapshot,
 614    ) -> impl Iterator<Item = (Range<Anchor>, String)> {
 615        text_diff(&old_text, new_text)
 616            .into_iter()
 617            .map(move |(mut old_range, new_text)| {
 618                old_range.start += offset;
 619                old_range.end += offset;
 620
 621                let prefix_len = common_prefix(
 622                    snapshot.chars_for_range(old_range.clone()),
 623                    new_text.chars(),
 624                );
 625                old_range.start += prefix_len;
 626
 627                let suffix_len = common_prefix(
 628                    snapshot.reversed_chars_for_range(old_range.clone()),
 629                    new_text[prefix_len..].chars().rev(),
 630                );
 631                old_range.end = old_range.end.saturating_sub(suffix_len);
 632
 633                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
 634                let range = if old_range.is_empty() {
 635                    let anchor = snapshot.anchor_after(old_range.start);
 636                    anchor..anchor
 637                } else {
 638                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end)
 639                };
 640                (range, new_text)
 641            })
 642    }
 643
 644    fn gather_nearby_diagnostics(
 645        cursor_offset: usize,
 646        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 647        snapshot: &BufferSnapshot,
 648        max_diagnostics_bytes: usize,
 649    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 650        // TODO: Could make this more efficient
 651        let mut diagnostic_groups = Vec::new();
 652        for (language_server_id, diagnostics) in diagnostic_sets {
 653            let mut groups = Vec::new();
 654            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 655            diagnostic_groups.extend(
 656                groups
 657                    .into_iter()
 658                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 659            );
 660        }
 661
 662        // sort by proximity to cursor
 663        diagnostic_groups.sort_by_key(|group| {
 664            let range = &group.entries[group.primary_ix].range;
 665            if range.start >= cursor_offset {
 666                range.start - cursor_offset
 667            } else if cursor_offset >= range.end {
 668                cursor_offset - range.end
 669            } else {
 670                (cursor_offset - range.start).min(range.end - cursor_offset)
 671            }
 672        });
 673
 674        let mut results = Vec::new();
 675        let mut diagnostic_groups_truncated = false;
 676        let mut diagnostics_byte_count = 0;
 677        for group in diagnostic_groups {
 678            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 679            diagnostics_byte_count += raw_value.get().len();
 680            if diagnostics_byte_count > max_diagnostics_bytes {
 681                diagnostic_groups_truncated = true;
 682                break;
 683            }
 684            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 685        }
 686
 687        (results, diagnostic_groups_truncated)
 688    }
 689
 690    // TODO: Dedupe with similar code in request_prediction?
 691    pub fn cloud_request_for_zeta_cli(
 692        &mut self,
 693        project: &Entity<Project>,
 694        buffer: &Entity<Buffer>,
 695        position: language::Anchor,
 696        cx: &mut Context<Self>,
 697    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 698        let project_state = self.projects.get(&project.entity_id());
 699
 700        let index_state = project_state.map(|state| {
 701            state
 702                .syntax_index
 703                .read_with(cx, |index, _cx| index.state().clone())
 704        });
 705        let options = self.options.clone();
 706        let snapshot = buffer.read(cx).snapshot();
 707        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 708            return Task::ready(Err(anyhow!("No file path for excerpt")));
 709        };
 710        let worktree_snapshots = project
 711            .read(cx)
 712            .worktrees(cx)
 713            .map(|worktree| worktree.read(cx).snapshot())
 714            .collect::<Vec<_>>();
 715
 716        cx.background_spawn(async move {
 717            let index_state = if let Some(index_state) = index_state {
 718                Some(index_state.lock_owned().await)
 719            } else {
 720                None
 721            };
 722
 723            let cursor_point = position.to_point(&snapshot);
 724
 725            let debug_info = true;
 726            EditPredictionContext::gather_context(
 727                cursor_point,
 728                &snapshot,
 729                &options.excerpt,
 730                index_state.as_deref(),
 731            )
 732            .context("Failed to select excerpt")
 733            .map(|context| {
 734                make_cloud_request(
 735                    excerpt_path.clone(),
 736                    context,
 737                    // TODO pass everything
 738                    Vec::new(),
 739                    false,
 740                    Vec::new(),
 741                    false,
 742                    None,
 743                    debug_info,
 744                    &worktree_snapshots,
 745                    index_state.as_deref(),
 746                    Some(options.max_prompt_bytes),
 747                )
 748            })
 749        })
 750    }
 751}
 752
 753fn common_prefix<T1: Iterator<Item = char>, T2: Iterator<Item = char>>(a: T1, b: T2) -> usize {
 754    a.zip(b)
 755        .take_while(|(a, b)| a == b)
 756        .map(|(a, _)| a.len_utf8())
 757        .sum()
 758}
 759
 760#[derive(Error, Debug)]
 761#[error(
 762    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 763)]
 764pub struct ZedUpdateRequiredError {
 765    minimum_version: SemanticVersion,
 766}
 767
 768pub struct ZetaEditPredictionProvider {
 769    zeta: Entity<Zeta>,
 770    current_prediction: Option<CurrentEditPrediction>,
 771    next_pending_prediction_id: usize,
 772    pending_predictions: ArrayVec<PendingPrediction, 2>,
 773    last_request_timestamp: Instant,
 774}
 775
 776impl ZetaEditPredictionProvider {
 777    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 778
 779    pub fn new(
 780        project: Option<&Entity<Project>>,
 781        client: &Arc<Client>,
 782        user_store: &Entity<UserStore>,
 783        cx: &mut App,
 784    ) -> Self {
 785        let zeta = Zeta::global(client, user_store, cx);
 786        if let Some(project) = project {
 787            zeta.update(cx, |zeta, cx| {
 788                zeta.register_project(project, cx);
 789            });
 790        }
 791
 792        Self {
 793            zeta,
 794            current_prediction: None,
 795            next_pending_prediction_id: 0,
 796            pending_predictions: ArrayVec::new(),
 797            last_request_timestamp: Instant::now(),
 798        }
 799    }
 800}
 801
/// A prediction paired with the buffer it was generated for.
#[derive(Clone)]
struct CurrentEditPrediction {
    /// Entity id of the buffer the prediction's edits refer to.
    buffer_id: EntityId,
    /// The prediction itself.
    prediction: EditPrediction,
}
 807
 808impl CurrentEditPrediction {
 809    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 810        if self.buffer_id != old_prediction.buffer_id {
 811            return true;
 812        }
 813
 814        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 815            return true;
 816        };
 817        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 818            return false;
 819        };
 820
 821        if old_edits.len() == 1 && new_edits.len() == 1 {
 822            let (old_range, old_text) = &old_edits[0];
 823            let (new_range, new_text) = &new_edits[0];
 824            new_range == old_range && new_text.starts_with(old_text)
 825        } else {
 826            true
 827        }
 828    }
 829}
 830
 831#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 832pub struct EditPredictionId(Uuid);
 833
 834impl From<EditPredictionId> for gpui::ElementId {
 835    fn from(value: EditPredictionId) -> Self {
 836        gpui::ElementId::Uuid(value.0)
 837    }
 838}
 839
 840impl std::fmt::Display for EditPredictionId {
 841    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 842        write!(f, "{}", self.0)
 843    }
 844}
 845
 846#[derive(Clone)]
 847pub struct EditPrediction {
 848    id: EditPredictionId,
 849    edits: Arc<[(Range<Anchor>, String)]>,
 850    snapshot: BufferSnapshot,
 851    edit_preview: EditPreview,
 852}
 853
 854impl EditPrediction {
 855    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 856        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 857    }
 858}
 859
/// An in-flight prediction request tracked by `ZetaEditPredictionProvider`.
struct PendingPrediction {
    /// Sequence number taken from `next_pending_prediction_id` when the request was spawned.
    id: usize,
    /// Handle to the spawned request task. It is never read; it is held so the task lives as
    /// long as this entry (presumably dropping it cancels the request — gpui `Task` semantics).
    _task: Task<()>,
}
 864
 865impl EditPredictionProvider for ZetaEditPredictionProvider {
 866    fn name() -> &'static str {
 867        "zed-predict2"
 868    }
 869
 870    fn display_name() -> &'static str {
 871        "Zed's Edit Predictions 2"
 872    }
 873
 874    fn show_completions_in_menu() -> bool {
 875        true
 876    }
 877
 878    fn show_tab_accept_marker() -> bool {
 879        true
 880    }
 881
 882    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 883        // TODO [zeta2]
 884        DataCollectionState::Unsupported
 885    }
 886
 887    fn toggle_data_collection(&mut self, _cx: &mut App) {
 888        // TODO [zeta2]
 889    }
 890
 891    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 892        self.zeta.read(cx).usage(cx)
 893    }
 894
 895    fn is_enabled(
 896        &self,
 897        _buffer: &Entity<language::Buffer>,
 898        _cursor_position: language::Anchor,
 899        _cx: &App,
 900    ) -> bool {
 901        true
 902    }
 903
 904    fn is_refreshing(&self) -> bool {
 905        !self.pending_predictions.is_empty()
 906    }
 907
 908    fn refresh(
 909        &mut self,
 910        project: Option<Entity<project::Project>>,
 911        buffer: Entity<language::Buffer>,
 912        cursor_position: language::Anchor,
 913        _debounce: bool,
 914        cx: &mut Context<Self>,
 915    ) {
 916        let Some(project) = project else {
 917            return;
 918        };
 919
 920        if self
 921            .zeta
 922            .read(cx)
 923            .user_store
 924            .read_with(cx, |user_store, _cx| {
 925                user_store.account_too_young() || user_store.has_overdue_invoices()
 926            })
 927        {
 928            return;
 929        }
 930
 931        if let Some(current_prediction) = self.current_prediction.as_ref() {
 932            let snapshot = buffer.read(cx).snapshot();
 933            if current_prediction
 934                .prediction
 935                .interpolate(&snapshot)
 936                .is_some()
 937            {
 938                return;
 939            }
 940        }
 941
 942        let pending_prediction_id = self.next_pending_prediction_id;
 943        self.next_pending_prediction_id += 1;
 944        let last_request_timestamp = self.last_request_timestamp;
 945
 946        let task = cx.spawn(async move |this, cx| {
 947            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 948                .checked_duration_since(Instant::now())
 949            {
 950                cx.background_executor().timer(timeout).await;
 951            }
 952
 953            let prediction_request = this.update(cx, |this, cx| {
 954                this.last_request_timestamp = Instant::now();
 955                this.zeta.update(cx, |zeta, cx| {
 956                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 957                })
 958            });
 959
 960            let prediction = match prediction_request {
 961                Ok(prediction_request) => {
 962                    let prediction_request = prediction_request.await;
 963                    prediction_request.map(|c| {
 964                        c.map(|prediction| CurrentEditPrediction {
 965                            buffer_id: buffer.entity_id(),
 966                            prediction,
 967                        })
 968                    })
 969                }
 970                Err(error) => Err(error),
 971            };
 972
 973            this.update(cx, |this, cx| {
 974                if this.pending_predictions[0].id == pending_prediction_id {
 975                    this.pending_predictions.remove(0);
 976                } else {
 977                    this.pending_predictions.clear();
 978                }
 979
 980                let Some(new_prediction) = prediction
 981                    .context("edit prediction failed")
 982                    .log_err()
 983                    .flatten()
 984                else {
 985                    cx.notify();
 986                    return;
 987                };
 988
 989                if let Some(old_prediction) = this.current_prediction.as_ref() {
 990                    let snapshot = buffer.read(cx).snapshot();
 991                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 992                        this.current_prediction = Some(new_prediction);
 993                    }
 994                } else {
 995                    this.current_prediction = Some(new_prediction);
 996                }
 997
 998                cx.notify();
 999            })
1000            .ok();
1001        });
1002
1003        // We always maintain at most two pending predictions. When we already
1004        // have two, we replace the newest one.
1005        if self.pending_predictions.len() <= 1 {
1006            self.pending_predictions.push(PendingPrediction {
1007                id: pending_prediction_id,
1008                _task: task,
1009            });
1010        } else if self.pending_predictions.len() == 2 {
1011            self.pending_predictions.pop();
1012            self.pending_predictions.push(PendingPrediction {
1013                id: pending_prediction_id,
1014                _task: task,
1015            });
1016        }
1017
1018        cx.notify();
1019    }
1020
1021    fn cycle(
1022        &mut self,
1023        _buffer: Entity<language::Buffer>,
1024        _cursor_position: language::Anchor,
1025        _direction: Direction,
1026        _cx: &mut Context<Self>,
1027    ) {
1028    }
1029
1030    fn accept(&mut self, _cx: &mut Context<Self>) {
1031        // TODO [zeta2] report accept
1032        self.current_prediction.take();
1033        self.pending_predictions.clear();
1034    }
1035
1036    fn discard(&mut self, _cx: &mut Context<Self>) {
1037        self.pending_predictions.clear();
1038        self.current_prediction.take();
1039    }
1040
1041    fn suggest(
1042        &mut self,
1043        buffer: &Entity<language::Buffer>,
1044        cursor_position: language::Anchor,
1045        cx: &mut Context<Self>,
1046    ) -> Option<edit_prediction::EditPrediction> {
1047        let CurrentEditPrediction {
1048            buffer_id,
1049            prediction,
1050            ..
1051        } = self.current_prediction.as_mut()?;
1052
1053        // Invalidate previous prediction if it was generated for a different buffer.
1054        if *buffer_id != buffer.entity_id() {
1055            self.current_prediction.take();
1056            return None;
1057        }
1058
1059        let buffer = buffer.read(cx);
1060        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1061            self.current_prediction.take();
1062            return None;
1063        };
1064
1065        let cursor_row = cursor_position.to_point(buffer).row;
1066        let (closest_edit_ix, (closest_edit_range, _)) =
1067            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1068                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1069                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1070                cmp::min(distance_from_start, distance_from_end)
1071            })?;
1072
1073        let mut edit_start_ix = closest_edit_ix;
1074        for (range, _) in edits[..edit_start_ix].iter().rev() {
1075            let distance_from_closest_edit =
1076                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1077            if distance_from_closest_edit <= 1 {
1078                edit_start_ix -= 1;
1079            } else {
1080                break;
1081            }
1082        }
1083
1084        let mut edit_end_ix = closest_edit_ix + 1;
1085        for (range, _) in &edits[edit_end_ix..] {
1086            let distance_from_closest_edit =
1087                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1088            if distance_from_closest_edit <= 1 {
1089                edit_end_ix += 1;
1090            } else {
1091                break;
1092            }
1093        }
1094
1095        Some(edit_prediction::EditPrediction {
1096            id: Some(prediction.id.to_string().into()),
1097            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1098            edit_preview: Some(prediction.edit_preview.clone()),
1099        })
1100    }
1101}
1102
1103fn make_cloud_request(
1104    excerpt_path: PathBuf,
1105    context: EditPredictionContext,
1106    events: Vec<predict_edits_v3::Event>,
1107    can_collect_data: bool,
1108    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1109    diagnostic_groups_truncated: bool,
1110    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1111    debug_info: bool,
1112    worktrees: &Vec<worktree::Snapshot>,
1113    index_state: Option<&SyntaxIndexState>,
1114    prompt_max_bytes: Option<usize>,
1115) -> predict_edits_v3::PredictEditsRequest {
1116    let mut signatures = Vec::new();
1117    let mut declaration_to_signature_index = HashMap::default();
1118    let mut referenced_declarations = Vec::new();
1119
1120    for snippet in context.snippets {
1121        let project_entry_id = snippet.declaration.project_entry_id();
1122        let Some(path) = worktrees.iter().find_map(|worktree| {
1123            worktree.entry_for_id(project_entry_id).map(|entry| {
1124                let mut full_path = PathBuf::new();
1125                full_path.push(worktree.root_name());
1126                full_path.push(&entry.path);
1127                full_path
1128            })
1129        }) else {
1130            continue;
1131        };
1132
1133        let parent_index = index_state.and_then(|index_state| {
1134            snippet.declaration.parent().and_then(|parent| {
1135                add_signature(
1136                    parent,
1137                    &mut declaration_to_signature_index,
1138                    &mut signatures,
1139                    index_state,
1140                )
1141            })
1142        });
1143
1144        let (text, text_is_truncated) = snippet.declaration.item_text();
1145        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1146            path,
1147            text: text.into(),
1148            range: snippet.declaration.item_range(),
1149            text_is_truncated,
1150            signature_range: snippet.declaration.signature_range_in_item_text(),
1151            parent_index,
1152            score_components: snippet.score_components,
1153            signature_score: snippet.scores.signature,
1154            declaration_score: snippet.scores.declaration,
1155        });
1156    }
1157
1158    let excerpt_parent = index_state.and_then(|index_state| {
1159        context
1160            .excerpt
1161            .parent_declarations
1162            .last()
1163            .and_then(|(parent, _)| {
1164                add_signature(
1165                    *parent,
1166                    &mut declaration_to_signature_index,
1167                    &mut signatures,
1168                    index_state,
1169                )
1170            })
1171    });
1172
1173    predict_edits_v3::PredictEditsRequest {
1174        excerpt_path,
1175        excerpt: context.excerpt_text.body,
1176        excerpt_range: context.excerpt.range,
1177        cursor_offset: context.cursor_offset_in_excerpt,
1178        referenced_declarations,
1179        signatures,
1180        excerpt_parent,
1181        events,
1182        can_collect_data,
1183        diagnostic_groups,
1184        diagnostic_groups_truncated,
1185        git_info,
1186        debug_info,
1187        prompt_max_bytes,
1188    }
1189}
1190
1191fn add_signature(
1192    declaration_id: DeclarationId,
1193    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1194    signatures: &mut Vec<Signature>,
1195    index: &SyntaxIndexState,
1196) -> Option<usize> {
1197    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1198        return Some(*signature_index);
1199    }
1200    let Some(parent_declaration) = index.declaration(declaration_id) else {
1201        log::error!("bug: missing parent declaration");
1202        return None;
1203    };
1204    let parent_index = parent_declaration.parent().and_then(|parent| {
1205        add_signature(parent, declaration_to_signature_index, signatures, index)
1206    });
1207    let (text, text_is_truncated) = parent_declaration.signature_text();
1208    let signature_index = signatures.len();
1209    signatures.push(Signature {
1210        text: text.into(),
1211        text_is_truncated,
1212        parent_index,
1213        range: parent_declaration.signature_range(),
1214    });
1215    declaration_to_signature_index.insert(declaration_id, signature_index);
1216    Some(signature_index)
1217}
1218
1219fn interpolate(
1220    old_snapshot: &BufferSnapshot,
1221    new_snapshot: &BufferSnapshot,
1222    current_edits: Arc<[(Range<Anchor>, String)]>,
1223) -> Option<Vec<(Range<Anchor>, String)>> {
1224    let mut edits = Vec::new();
1225
1226    let mut model_edits = current_edits.iter().peekable();
1227    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1228        while let Some((model_old_range, _)) = model_edits.peek() {
1229            let model_old_range = model_old_range.to_offset(old_snapshot);
1230            if model_old_range.end < user_edit.old.start {
1231                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1232                edits.push((model_old_range.clone(), model_new_text.clone()));
1233            } else {
1234                break;
1235            }
1236        }
1237
1238        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1239            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1240            if user_edit.old == model_old_offset_range {
1241                let user_new_text = new_snapshot
1242                    .text_for_range(user_edit.new.clone())
1243                    .collect::<String>();
1244
1245                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1246                    if !model_suffix.is_empty() {
1247                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1248                        edits.push((anchor..anchor, model_suffix.to_string()));
1249                    }
1250
1251                    model_edits.next();
1252                    continue;
1253                }
1254            }
1255        }
1256
1257        return None;
1258    }
1259
1260    edits.extend(model_edits.cloned());
1261
1262    if edits.is_empty() { None } else { Some(edits) }
1263}
1264
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use indoc::indoc;

    #[gpui::test]
    async fn test_compute_edits(cx: &mut TestAppContext) {
        let old = indoc! {r#"
            fn main() {
                let args =
                println!("{}", args[1])
            }
        "#};

        let new = indoc! {r#"
            fn main() {
                let args = std::env::args();
                println!("{}", args[1]);
            }
        "#};

        let buffer = cx.new(|cx| Buffer::local(old, cx));
        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot());

        // A single whole-file replacement should be reduced to the two minimal edits.
        // TODO cover more cases when multi-file is supported
        let server_edits = vec![predict_edits_v3::Edit {
            path: PathBuf::from("test.txt"),
            range: 0..old.len(),
            content: new.into(),
        }];

        let edits = Zeta::compute_edits(&server_edits, &snapshot);
        assert_eq!(edits.len(), 2);

        let first_start = edits[0].0.to_point(&snapshot).start;
        assert_eq!(first_start, language::Point::new(1, 14));
        assert_eq!(edits[0].1, " std::env::args();");

        let second_start = edits[1].0.to_point(&snapshot).start;
        assert_eq!(second_start, language::Point::new(2, 27));
        assert_eq!(edits[1].1, ";");
    }

    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction applies verbatim.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Deleting the predicted range turns the first edit into a pure insertion.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            // Typing a prefix of the predicted text leaves only the remaining suffix.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Typing the full predicted text consumes that edit.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            // Deleting the last typed character brings its suggestion back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            // Performing the predicted deletion by hand consumes it.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // An edit that diverges from the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and converts the result
    /// to offset ranges. Panics if the prediction no longer applies.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for the given buffer.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Resolves anchor-based edits back into offsets in the given buffer.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}