zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  10use edit_prediction_context::{
  11    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  12    SyntaxIndexState,
  13};
  14use futures::AsyncReadExt as _;
  15use futures::channel::mpsc;
  16use gpui::http_client::Method;
  17use gpui::{
  18    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  19    http_client, prelude::*,
  20};
  21use language::{Anchor, Buffer, OffsetRangeExt as _, ToPoint};
  22use language::{BufferSnapshot, EditPreview};
  23use language_model::{LlmApiToken, RefreshLlmTokenListener};
  24use project::Project;
  25use release_channel::AppVersion;
  26use std::cmp;
  27use std::collections::{HashMap, VecDeque, hash_map};
  28use std::path::PathBuf;
  29use std::str::FromStr as _;
  30use std::time::{Duration, Instant};
  31use std::{ops::Range, sync::Arc};
  32use thiserror::Error;
  33use util::{ResultExt as _, some_or_debug_panic};
  34use uuid::Uuid;
  35use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  36
/// Maximum time between two consecutive changes to the same buffer for them to be
/// coalesced into a single `Event::BufferChange` (see `Zeta::push_event`).
const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);

/// Maximum number of events to track.
const MAX_EVENT_COUNT: usize = 16;

/// Excerpt options that a newly created `Zeta` starts with.
pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
    max_bytes: 512,
    min_bytes: 128,
    target_before_cursor_over_total_bytes: 0.5,
};
  47
/// App-global wrapper holding the shared `Zeta` entity (see `Zeta::global`).
#[derive(Clone)]
struct ZetaGlobal(Entity<Zeta>);

impl Global for ZetaGlobal {}
  52
/// Shared state used to request edit predictions from the cloud service.
pub struct Zeta {
    client: Arc<Client>,
    user_store: Entity<UserStore>,
    /// Token sent as the bearer credential of prediction requests.
    llm_token: LlmApiToken,
    /// Keeps the token-refresh listener subscription alive.
    _llm_token_subscription: Subscription,
    /// Per-project state, keyed by the project's entity id.
    projects: HashMap<EntityId, ZetaProject>,
    pub excerpt_options: EditPredictionExcerptOptions,
    /// Set once a request failed with `ZedUpdateRequiredError`.
    update_required: bool,
    /// When set (via `debug_info`), each prediction request reports its debug info here.
    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
}
  63
/// Debug information reported for a single prediction request.
pub struct PredictionDebugInfo {
    /// The context that was gathered around the cursor for the request.
    pub context: EditPredictionContext,
    /// Time spent gathering the context and building the cloud request.
    pub retrieval_time: TimeDelta,
    /// Debug info returned by the server in its response.
    pub request: RequestDebugInfo,
    /// The buffer the prediction was requested for.
    pub buffer: WeakEntity<Buffer>,
    /// The cursor position the prediction was requested at.
    pub position: language::Anchor,
}

/// Debug info attached to a `predict_edits_v3` response.
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  73
/// State tracked by `Zeta` for one registered project.
struct ZetaProject {
    syntax_index: Entity<SyntaxIndex>,
    /// Recent buffer changes, oldest first; bounded by `MAX_EVENT_COUNT`.
    events: VecDeque<Event>,
    /// Buffers whose edits are being tracked, keyed by the buffer's entity id.
    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
}
  79
/// A buffer whose edits are tracked for a project.
struct RegisteredBuffer {
    /// The most recent snapshot for which changes have been reported.
    snapshot: BufferSnapshot,
    /// Subscriptions to the buffer's events and to its release; held to keep them alive.
    _subscriptions: [gpui::Subscription; 2],
}
  84
/// An event recorded for a project and sent along with prediction requests.
#[derive(Clone)]
pub enum Event {
    /// A buffer's contents changed from `old_snapshot` to `new_snapshot`.
    BufferChange {
        old_snapshot: BufferSnapshot,
        new_snapshot: BufferSnapshot,
        /// When the (most recent coalesced) change was recorded.
        timestamp: Instant,
    },
}
  93
  94impl Zeta {
  95    pub fn global(
  96        client: &Arc<Client>,
  97        user_store: &Entity<UserStore>,
  98        cx: &mut App,
  99    ) -> Entity<Self> {
 100        cx.try_global::<ZetaGlobal>()
 101            .map(|global| global.0.clone())
 102            .unwrap_or_else(|| {
 103                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 104                cx.set_global(ZetaGlobal(zeta.clone()));
 105                zeta
 106            })
 107    }
 108
 109    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 110        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 111
 112        Self {
 113            projects: HashMap::new(),
 114            client,
 115            user_store,
 116            excerpt_options: DEFAULT_EXCERPT_OPTIONS,
 117            llm_token: LlmApiToken::default(),
 118            _llm_token_subscription: cx.subscribe(
 119                &refresh_llm_token_listener,
 120                |this, _listener, _event, cx| {
 121                    let client = this.client.clone();
 122                    let llm_token = this.llm_token.clone();
 123                    cx.spawn(async move |_this, _cx| {
 124                        llm_token.refresh(&client).await?;
 125                        anyhow::Ok(())
 126                    })
 127                    .detach_and_log_err(cx);
 128                },
 129            ),
 130            update_required: false,
 131            debug_tx: None,
 132        }
 133    }
 134
 135    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 136        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 137        self.debug_tx = Some(debug_watch_tx);
 138        debug_watch_rx
 139    }
 140
    /// Returns the excerpt options currently used when gathering context.
    pub fn excerpt_options(&self) -> &EditPredictionExcerptOptions {
        &self.excerpt_options
    }
 144
    /// Replaces the excerpt options used by subsequent requests.
    pub fn set_excerpt_options(&mut self, options: EditPredictionExcerptOptions) {
        self.excerpt_options = options;
    }
 148
 149    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 150        self.user_store.read(cx).edit_prediction_usage()
 151    }
 152
    /// Ensures per-project state exists for `project`, creating it if necessary.
    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
        self.get_or_init_zeta_project(project, cx);
    }
 156
 157    pub fn register_buffer(
 158        &mut self,
 159        buffer: &Entity<Buffer>,
 160        project: &Entity<Project>,
 161        cx: &mut Context<Self>,
 162    ) {
 163        let zeta_project = self.get_or_init_zeta_project(project, cx);
 164        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 165    }
 166
 167    fn get_or_init_zeta_project(
 168        &mut self,
 169        project: &Entity<Project>,
 170        cx: &mut App,
 171    ) -> &mut ZetaProject {
 172        self.projects
 173            .entry(project.entity_id())
 174            .or_insert_with(|| ZetaProject {
 175                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 176                events: VecDeque::new(),
 177                registered_buffers: HashMap::new(),
 178            })
 179    }
 180
 181    fn register_buffer_impl<'a>(
 182        zeta_project: &'a mut ZetaProject,
 183        buffer: &Entity<Buffer>,
 184        project: &Entity<Project>,
 185        cx: &mut Context<Self>,
 186    ) -> &'a mut RegisteredBuffer {
 187        let buffer_id = buffer.entity_id();
 188        match zeta_project.registered_buffers.entry(buffer_id) {
 189            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 190            hash_map::Entry::Vacant(entry) => {
 191                let snapshot = buffer.read(cx).snapshot();
 192                let project_entity_id = project.entity_id();
 193                entry.insert(RegisteredBuffer {
 194                    snapshot,
 195                    _subscriptions: [
 196                        cx.subscribe(buffer, {
 197                            let project = project.downgrade();
 198                            move |this, buffer, event, cx| {
 199                                if let language::BufferEvent::Edited = event
 200                                    && let Some(project) = project.upgrade()
 201                                {
 202                                    this.report_changes_for_buffer(&buffer, &project, cx);
 203                                }
 204                            }
 205                        }),
 206                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 207                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 208                            else {
 209                                return;
 210                            };
 211                            zeta_project.registered_buffers.remove(&buffer_id);
 212                        }),
 213                    ],
 214                })
 215            }
 216        }
 217    }
 218
 219    fn report_changes_for_buffer(
 220        &mut self,
 221        buffer: &Entity<Buffer>,
 222        project: &Entity<Project>,
 223        cx: &mut Context<Self>,
 224    ) -> BufferSnapshot {
 225        let zeta_project = self.get_or_init_zeta_project(project, cx);
 226        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 227
 228        let new_snapshot = buffer.read(cx).snapshot();
 229        if new_snapshot.version != registered_buffer.snapshot.version {
 230            let old_snapshot =
 231                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 232            Self::push_event(
 233                zeta_project,
 234                Event::BufferChange {
 235                    old_snapshot,
 236                    new_snapshot: new_snapshot.clone(),
 237                    timestamp: Instant::now(),
 238                },
 239            );
 240        }
 241
 242        new_snapshot
 243    }
 244
 245    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 246        let events = &mut zeta_project.events;
 247
 248        if let Some(Event::BufferChange {
 249            new_snapshot: last_new_snapshot,
 250            timestamp: last_timestamp,
 251            ..
 252        }) = events.back_mut()
 253        {
 254            // Coalesce edits for the same buffer when they happen one after the other.
 255            let Event::BufferChange {
 256                old_snapshot,
 257                new_snapshot,
 258                timestamp,
 259            } = &event;
 260
 261            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 262                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 263                && old_snapshot.version == last_new_snapshot.version
 264            {
 265                *last_new_snapshot = new_snapshot.clone();
 266                *last_timestamp = *timestamp;
 267                return;
 268            }
 269        }
 270
 271        if events.len() >= MAX_EVENT_COUNT {
 272            // These are halved instead of popping to improve prompt caching.
 273            events.drain(..MAX_EVENT_COUNT / 2);
 274        }
 275
 276        events.push_back(event);
 277    }
 278
 279    pub fn request_prediction(
 280        &mut self,
 281        project: &Entity<Project>,
 282        buffer: &Entity<Buffer>,
 283        position: language::Anchor,
 284        cx: &mut Context<Self>,
 285    ) -> Task<Result<Option<EditPrediction>>> {
 286        let project_state = self.projects.get(&project.entity_id());
 287
 288        let index_state = project_state.map(|state| {
 289            state
 290                .syntax_index
 291                .read_with(cx, |index, _cx| index.state().clone())
 292        });
 293        let excerpt_options = self.excerpt_options.clone();
 294        let snapshot = buffer.read(cx).snapshot();
 295        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 296            return Task::ready(Err(anyhow!("No file path for excerpt")));
 297        };
 298        let client = self.client.clone();
 299        let llm_token = self.llm_token.clone();
 300        let app_version = AppVersion::global(cx);
 301        let worktree_snapshots = project
 302            .read(cx)
 303            .worktrees(cx)
 304            .map(|worktree| worktree.read(cx).snapshot())
 305            .collect::<Vec<_>>();
 306        let debug_tx = self.debug_tx.clone();
 307
 308        let events = project_state
 309            .map(|state| {
 310                state
 311                    .events
 312                    .iter()
 313                    .map(|event| match event {
 314                        Event::BufferChange {
 315                            old_snapshot,
 316                            new_snapshot,
 317                            ..
 318                        } => {
 319                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
 320
 321                            let old_path = old_snapshot.file().and_then(|f| {
 322                                let old_path = f.path().as_ref();
 323                                if Some(old_path) != path.as_deref() {
 324                                    Some(old_path.to_path_buf())
 325                                } else {
 326                                    None
 327                                }
 328                            });
 329
 330                            predict_edits_v3::Event::BufferChange {
 331                                old_path,
 332                                path,
 333                                diff: language::unified_diff(
 334                                    &old_snapshot.text(),
 335                                    &new_snapshot.text(),
 336                                ),
 337                                //todo: Actually detect if this edit was predicted or not
 338                                predicted: false,
 339                            }
 340                        }
 341                    })
 342                    .collect::<Vec<_>>()
 343            })
 344            .unwrap_or_default();
 345
 346        let request_task = cx.background_spawn({
 347            let snapshot = snapshot.clone();
 348            let buffer = buffer.clone();
 349            async move {
 350                let index_state = if let Some(index_state) = index_state {
 351                    Some(index_state.lock_owned().await)
 352                } else {
 353                    None
 354                };
 355
 356                let cursor_point = position.to_point(&snapshot);
 357
 358                let before_retrieval = chrono::Utc::now();
 359
 360                let Some(context) = EditPredictionContext::gather_context(
 361                    cursor_point,
 362                    &snapshot,
 363                    &excerpt_options,
 364                    index_state.as_deref(),
 365                ) else {
 366                    return Ok(None);
 367                };
 368
 369                let debug_context = if let Some(debug_tx) = debug_tx {
 370                    Some((debug_tx, context.clone()))
 371                } else {
 372                    None
 373                };
 374
 375                let request = make_cloud_request(
 376                    excerpt_path.clone(),
 377                    context,
 378                    events,
 379                    // TODO data collection
 380                    false,
 381                    Vec::new(),
 382                    None,
 383                    debug_context.is_some(),
 384                    &worktree_snapshots,
 385                    index_state.as_deref(),
 386                );
 387
 388                let retrieval_time = chrono::Utc::now() - before_retrieval;
 389                let response = Self::perform_request(client, llm_token, app_version, request).await;
 390
 391                if let Some((debug_tx, context)) = debug_context {
 392                    debug_tx
 393                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 394                            |response| {
 395                                let Some(request) =
 396                                    some_or_debug_panic(response.0.debug_info.clone())
 397                                else {
 398                                    return Err("Missing debug info".to_string());
 399                                };
 400                                Ok(PredictionDebugInfo {
 401                                    context,
 402                                    request,
 403                                    retrieval_time,
 404                                    buffer: buffer.downgrade(),
 405                                    position,
 406                                })
 407                            },
 408                        ))
 409                        .ok();
 410                }
 411
 412                anyhow::Ok(Some(response?))
 413            }
 414        });
 415
 416        let buffer = buffer.clone();
 417
 418        cx.spawn(async move |this, cx| {
 419            match request_task.await {
 420                Ok(Some((response, usage))) => {
 421                    log::debug!("predicted edits: {:?}", &response.edits);
 422
 423                    if let Some(usage) = usage {
 424                        this.update(cx, |this, cx| {
 425                            this.user_store.update(cx, |user_store, cx| {
 426                                user_store.update_edit_prediction_usage(usage, cx);
 427                            });
 428                        })
 429                        .ok();
 430                    }
 431
 432                    // TODO telemetry: duration, etc
 433
 434                    // TODO produce smaller edits by diffing against snapshot first
 435                    //
 436                    // Cloud returns entire snippets/excerpts ranges as they were included
 437                    // in the request, but we should display smaller edits to the user.
 438                    //
 439                    // We can do this by computing a diff of each one against the snapshot.
 440                    // Similar to zeta::Zeta::compute_edits, but per edit.
 441                    let edits = response
 442                        .edits
 443                        .into_iter()
 444                        .map(|edit| {
 445                            // TODO edits to different files
 446                            (
 447                                snapshot.anchor_before(edit.range.start)
 448                                    ..snapshot.anchor_before(edit.range.end),
 449                                edit.content,
 450                            )
 451                        })
 452                        .collect::<Vec<_>>()
 453                        .into();
 454
 455                    let Some((edits, snapshot, edit_preview_task)) =
 456                        buffer.read_with(cx, |buffer, cx| {
 457                            let new_snapshot = buffer.snapshot();
 458                            let edits: Arc<[_]> =
 459                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 460                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 461                        })?
 462                    else {
 463                        return Ok(None);
 464                    };
 465
 466                    Ok(Some(EditPrediction {
 467                        id: EditPredictionId(response.request_id),
 468                        edits,
 469                        snapshot,
 470                        edit_preview: edit_preview_task.await,
 471                    }))
 472                }
 473                Ok(None) => Ok(None),
 474                Err(err) => {
 475                    if err.is::<ZedUpdateRequiredError>() {
 476                        cx.update(|cx| {
 477                            this.update(cx, |this, _cx| {
 478                                this.update_required = true;
 479                            })
 480                            .ok();
 481
 482                            let error_message: SharedString = err.to_string().into();
 483                            show_app_notification(
 484                                NotificationId::unique::<ZedUpdateRequiredError>(),
 485                                cx,
 486                                move |cx| {
 487                                    cx.new(|cx| {
 488                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 489                                            .with_link_button(
 490                                                "Update Zed",
 491                                                "https://zed.dev/releases",
 492                                            )
 493                                    })
 494                                },
 495                            );
 496                        })
 497                        .ok();
 498                    }
 499
 500                    Err(err)
 501                }
 502            }
 503        })
 504    }
 505
 506    async fn perform_request(
 507        client: Arc<Client>,
 508        llm_token: LlmApiToken,
 509        app_version: SemanticVersion,
 510        request: predict_edits_v3::PredictEditsRequest,
 511    ) -> Result<(
 512        predict_edits_v3::PredictEditsResponse,
 513        Option<EditPredictionUsage>,
 514    )> {
 515        let http_client = client.http_client();
 516        let mut token = llm_token.acquire(&client).await?;
 517        let mut did_retry = false;
 518
 519        loop {
 520            let request_builder = http_client::Request::builder().method(Method::POST);
 521            let request_builder =
 522                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 523                    request_builder.uri(predict_edits_url)
 524                } else {
 525                    request_builder.uri(
 526                        http_client
 527                            .build_zed_llm_url("/predict_edits/v3", &[])?
 528                            .as_ref(),
 529                    )
 530                };
 531            let request = request_builder
 532                .header("Content-Type", "application/json")
 533                .header("Authorization", format!("Bearer {}", token))
 534                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 535                .body(serde_json::to_string(&request)?.into())?;
 536
 537            let mut response = http_client.send(request).await?;
 538
 539            if let Some(minimum_required_version) = response
 540                .headers()
 541                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 542                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 543            {
 544                anyhow::ensure!(
 545                    app_version >= minimum_required_version,
 546                    ZedUpdateRequiredError {
 547                        minimum_version: minimum_required_version
 548                    }
 549                );
 550            }
 551
 552            if response.status().is_success() {
 553                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 554
 555                let mut body = Vec::new();
 556                response.body_mut().read_to_end(&mut body).await?;
 557                return Ok((serde_json::from_slice(&body)?, usage));
 558            } else if !did_retry
 559                && response
 560                    .headers()
 561                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 562                    .is_some()
 563            {
 564                did_retry = true;
 565                token = llm_token.refresh(&client).await?;
 566            } else {
 567                let mut body = String::new();
 568                response.body_mut().read_to_string(&mut body).await?;
 569                anyhow::bail!(
 570                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 571                    response.status(),
 572                    body
 573                );
 574            }
 575        }
 576    }
 577
 578    // TODO: Dedupe with similar code in request_prediction?
 579    pub fn cloud_request_for_zeta_cli(
 580        &mut self,
 581        project: &Entity<Project>,
 582        buffer: &Entity<Buffer>,
 583        position: language::Anchor,
 584        cx: &mut Context<Self>,
 585    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 586        let project_state = self.projects.get(&project.entity_id());
 587
 588        let index_state = project_state.map(|state| {
 589            state
 590                .syntax_index
 591                .read_with(cx, |index, _cx| index.state().clone())
 592        });
 593        let excerpt_options = self.excerpt_options.clone();
 594        let snapshot = buffer.read(cx).snapshot();
 595        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 596            return Task::ready(Err(anyhow!("No file path for excerpt")));
 597        };
 598        let worktree_snapshots = project
 599            .read(cx)
 600            .worktrees(cx)
 601            .map(|worktree| worktree.read(cx).snapshot())
 602            .collect::<Vec<_>>();
 603
 604        cx.background_spawn(async move {
 605            let index_state = if let Some(index_state) = index_state {
 606                Some(index_state.lock_owned().await)
 607            } else {
 608                None
 609            };
 610
 611            let cursor_point = position.to_point(&snapshot);
 612
 613            let debug_info = true;
 614            EditPredictionContext::gather_context(
 615                cursor_point,
 616                &snapshot,
 617                &excerpt_options,
 618                index_state.as_deref(),
 619            )
 620            .context("Failed to select excerpt")
 621            .map(|context| {
 622                make_cloud_request(
 623                    excerpt_path.clone(),
 624                    context,
 625                    // TODO pass everything
 626                    Vec::new(),
 627                    false,
 628                    Vec::new(),
 629                    None,
 630                    debug_info,
 631                    &worktree_snapshots,
 632                    index_state.as_deref(),
 633                )
 634            })
 635        })
 636    }
 637}
 638
/// Error returned when the server reports a minimum required Zed version that is
/// newer than the running app version.
#[derive(Error, Debug)]
#[error(
    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
)]
pub struct ZedUpdateRequiredError {
    /// The minimum version advertised by the server.
    minimum_version: SemanticVersion,
}
 646
/// `EditPredictionProvider` implementation backed by the shared `Zeta` entity.
pub struct ZetaEditPredictionProvider {
    zeta: Entity<Zeta>,
    /// The prediction currently held by the provider, if any.
    current_prediction: Option<CurrentEditPrediction>,
    /// Id that will be assigned to the next pending prediction.
    next_pending_prediction_id: usize,
    /// In-flight prediction requests (at most two).
    pending_predictions: ArrayVec<PendingPrediction, 2>,
    /// When the last prediction request was issued; used for throttling.
    last_request_timestamp: Instant,
}
 654
 655impl ZetaEditPredictionProvider {
 656    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 657
 658    pub fn new(
 659        project: Option<&Entity<Project>>,
 660        client: &Arc<Client>,
 661        user_store: &Entity<UserStore>,
 662        cx: &mut App,
 663    ) -> Self {
 664        let zeta = Zeta::global(client, user_store, cx);
 665        if let Some(project) = project {
 666            zeta.update(cx, |zeta, cx| {
 667                zeta.register_project(project, cx);
 668            });
 669        }
 670
 671        Self {
 672            zeta,
 673            current_prediction: None,
 674            next_pending_prediction_id: 0,
 675            pending_predictions: ArrayVec::new(),
 676            last_request_timestamp: Instant::now(),
 677        }
 678    }
 679}
 680
/// A prediction together with the id of the buffer it was produced for.
#[derive(Clone)]
struct CurrentEditPrediction {
    buffer_id: EntityId,
    prediction: EditPrediction,
}
 686
 687impl CurrentEditPrediction {
 688    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 689        if self.buffer_id != old_prediction.buffer_id {
 690            return true;
 691        }
 692
 693        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 694            return true;
 695        };
 696        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 697            return false;
 698        };
 699
 700        if old_edits.len() == 1 && new_edits.len() == 1 {
 701            let (old_range, old_text) = &old_edits[0];
 702            let (new_range, new_text) = &new_edits[0];
 703            new_range == old_range && new_text.starts_with(old_text)
 704        } else {
 705            true
 706        }
 707    }
 708}
 709
/// Identifier of an edit prediction; wraps the request id returned by the server.
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(Uuid);
 712
 713impl From<EditPredictionId> for gpui::ElementId {
 714    fn from(value: EditPredictionId) -> Self {
 715        gpui::ElementId::Uuid(value.0)
 716    }
 717}
 718
 719impl std::fmt::Display for EditPredictionId {
 720    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 721        write!(f, "{}", self.0)
 722    }
 723}
 724
/// A set of predicted edits for a buffer.
#[derive(Clone)]
pub struct EditPrediction {
    id: EditPredictionId,
    /// The predicted edits, as anchor ranges paired with replacement text.
    edits: Arc<[(Range<Anchor>, String)]>,
    /// The buffer snapshot the edits were computed against.
    snapshot: BufferSnapshot,
    edit_preview: EditPreview,
}
 732
 733impl EditPrediction {
 734    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 735        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 736    }
 737}
 738
/// An in-flight prediction request.
struct PendingPrediction {
    /// Id assigned from `next_pending_prediction_id` when the request was started.
    id: usize,
    /// The task driving the request; held so that it is kept alive.
    _task: Task<()>,
}
 743
 744impl EditPredictionProvider for ZetaEditPredictionProvider {
    /// Returns the name string of this provider.
    fn name() -> &'static str {
        "zed-predict2"
    }
 748
    /// Returns the human-readable name of this provider.
    fn display_name() -> &'static str {
        "Zed's Edit Predictions 2"
    }
 752
    /// Always `true` for this provider.
    fn show_completions_in_menu() -> bool {
        true
    }
 756
    /// Always `true` for this provider.
    fn show_tab_accept_marker() -> bool {
        true
    }
 760
    /// Data collection is not implemented for this provider yet, so this always
    /// reports `Unsupported`.
    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
        // TODO [zeta2]
        DataCollectionState::Unsupported
    }
 765
    /// Currently a no-op; data collection is not implemented for this provider yet.
    fn toggle_data_collection(&mut self, _cx: &mut App) {
        // TODO [zeta2]
    }
 769
    /// Returns the edit prediction usage reported by the shared `Zeta` entity.
    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
        self.zeta.read(cx).usage(cx)
    }
 773
    /// Always `true`; the buffer and cursor position are not consulted.
    fn is_enabled(
        &self,
        _buffer: &Entity<language::Buffer>,
        _cursor_position: language::Anchor,
        _cx: &App,
    ) -> bool {
        true
    }
 782
    /// Returns `true` while at least one prediction request is pending.
    fn is_refreshing(&self) -> bool {
        !self.pending_predictions.is_empty()
    }
 786
 787    fn refresh(
 788        &mut self,
 789        project: Option<Entity<project::Project>>,
 790        buffer: Entity<language::Buffer>,
 791        cursor_position: language::Anchor,
 792        _debounce: bool,
 793        cx: &mut Context<Self>,
 794    ) {
 795        let Some(project) = project else {
 796            return;
 797        };
 798
 799        if self
 800            .zeta
 801            .read(cx)
 802            .user_store
 803            .read_with(cx, |user_store, _cx| {
 804                user_store.account_too_young() || user_store.has_overdue_invoices()
 805            })
 806        {
 807            return;
 808        }
 809
 810        if let Some(current_prediction) = self.current_prediction.as_ref() {
 811            let snapshot = buffer.read(cx).snapshot();
 812            if current_prediction
 813                .prediction
 814                .interpolate(&snapshot)
 815                .is_some()
 816            {
 817                return;
 818            }
 819        }
 820
 821        let pending_prediction_id = self.next_pending_prediction_id;
 822        self.next_pending_prediction_id += 1;
 823        let last_request_timestamp = self.last_request_timestamp;
 824
 825        let task = cx.spawn(async move |this, cx| {
 826            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 827                .checked_duration_since(Instant::now())
 828            {
 829                cx.background_executor().timer(timeout).await;
 830            }
 831
 832            let prediction_request = this.update(cx, |this, cx| {
 833                this.last_request_timestamp = Instant::now();
 834                this.zeta.update(cx, |zeta, cx| {
 835                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 836                })
 837            });
 838
 839            let prediction = match prediction_request {
 840                Ok(prediction_request) => {
 841                    let prediction_request = prediction_request.await;
 842                    prediction_request.map(|c| {
 843                        c.map(|prediction| CurrentEditPrediction {
 844                            buffer_id: buffer.entity_id(),
 845                            prediction,
 846                        })
 847                    })
 848                }
 849                Err(error) => Err(error),
 850            };
 851
 852            this.update(cx, |this, cx| {
 853                if this.pending_predictions[0].id == pending_prediction_id {
 854                    this.pending_predictions.remove(0);
 855                } else {
 856                    this.pending_predictions.clear();
 857                }
 858
 859                let Some(new_prediction) = prediction
 860                    .context("edit prediction failed")
 861                    .log_err()
 862                    .flatten()
 863                else {
 864                    cx.notify();
 865                    return;
 866                };
 867
 868                if let Some(old_prediction) = this.current_prediction.as_ref() {
 869                    let snapshot = buffer.read(cx).snapshot();
 870                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 871                        this.current_prediction = Some(new_prediction);
 872                    }
 873                } else {
 874                    this.current_prediction = Some(new_prediction);
 875                }
 876
 877                cx.notify();
 878            })
 879            .ok();
 880        });
 881
 882        // We always maintain at most two pending predictions. When we already
 883        // have two, we replace the newest one.
 884        if self.pending_predictions.len() <= 1 {
 885            self.pending_predictions.push(PendingPrediction {
 886                id: pending_prediction_id,
 887                _task: task,
 888            });
 889        } else if self.pending_predictions.len() == 2 {
 890            self.pending_predictions.pop();
 891            self.pending_predictions.push(PendingPrediction {
 892                id: pending_prediction_id,
 893                _task: task,
 894            });
 895        }
 896
 897        cx.notify();
 898    }
 899
 900    fn cycle(
 901        &mut self,
 902        _buffer: Entity<language::Buffer>,
 903        _cursor_position: language::Anchor,
 904        _direction: Direction,
 905        _cx: &mut Context<Self>,
 906    ) {
 907    }
 908
 909    fn accept(&mut self, _cx: &mut Context<Self>) {
 910        // TODO [zeta2] report accept
 911        self.current_prediction.take();
 912        self.pending_predictions.clear();
 913    }
 914
 915    fn discard(&mut self, _cx: &mut Context<Self>) {
 916        self.pending_predictions.clear();
 917        self.current_prediction.take();
 918    }
 919
 920    fn suggest(
 921        &mut self,
 922        buffer: &Entity<language::Buffer>,
 923        cursor_position: language::Anchor,
 924        cx: &mut Context<Self>,
 925    ) -> Option<edit_prediction::EditPrediction> {
 926        let CurrentEditPrediction {
 927            buffer_id,
 928            prediction,
 929            ..
 930        } = self.current_prediction.as_mut()?;
 931
 932        // Invalidate previous prediction if it was generated for a different buffer.
 933        if *buffer_id != buffer.entity_id() {
 934            self.current_prediction.take();
 935            return None;
 936        }
 937
 938        let buffer = buffer.read(cx);
 939        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
 940            self.current_prediction.take();
 941            return None;
 942        };
 943
 944        let cursor_row = cursor_position.to_point(buffer).row;
 945        let (closest_edit_ix, (closest_edit_range, _)) =
 946            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
 947                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
 948                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
 949                cmp::min(distance_from_start, distance_from_end)
 950            })?;
 951
 952        let mut edit_start_ix = closest_edit_ix;
 953        for (range, _) in edits[..edit_start_ix].iter().rev() {
 954            let distance_from_closest_edit =
 955                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
 956            if distance_from_closest_edit <= 1 {
 957                edit_start_ix -= 1;
 958            } else {
 959                break;
 960            }
 961        }
 962
 963        let mut edit_end_ix = closest_edit_ix + 1;
 964        for (range, _) in &edits[edit_end_ix..] {
 965            let distance_from_closest_edit =
 966                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
 967            if distance_from_closest_edit <= 1 {
 968                edit_end_ix += 1;
 969            } else {
 970                break;
 971            }
 972        }
 973
 974        Some(edit_prediction::EditPrediction {
 975            id: Some(prediction.id.to_string().into()),
 976            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
 977            edit_preview: Some(prediction.edit_preview.clone()),
 978        })
 979    }
 980}
 981
 982fn make_cloud_request(
 983    excerpt_path: PathBuf,
 984    context: EditPredictionContext,
 985    events: Vec<predict_edits_v3::Event>,
 986    can_collect_data: bool,
 987    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
 988    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
 989    debug_info: bool,
 990    worktrees: &Vec<worktree::Snapshot>,
 991    index_state: Option<&SyntaxIndexState>,
 992) -> predict_edits_v3::PredictEditsRequest {
 993    let mut signatures = Vec::new();
 994    let mut declaration_to_signature_index = HashMap::default();
 995    let mut referenced_declarations = Vec::new();
 996
 997    for snippet in context.snippets {
 998        let project_entry_id = snippet.declaration.project_entry_id();
 999        let Some(path) = worktrees.iter().find_map(|worktree| {
1000            worktree.entry_for_id(project_entry_id).map(|entry| {
1001                let mut full_path = PathBuf::new();
1002                full_path.push(worktree.root_name());
1003                full_path.push(&entry.path);
1004                full_path
1005            })
1006        }) else {
1007            continue;
1008        };
1009
1010        let parent_index = index_state.and_then(|index_state| {
1011            snippet.declaration.parent().and_then(|parent| {
1012                add_signature(
1013                    parent,
1014                    &mut declaration_to_signature_index,
1015                    &mut signatures,
1016                    index_state,
1017                )
1018            })
1019        });
1020
1021        let (text, text_is_truncated) = snippet.declaration.item_text();
1022        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1023            path,
1024            text: text.into(),
1025            range: snippet.declaration.item_range(),
1026            text_is_truncated,
1027            signature_range: snippet.declaration.signature_range_in_item_text(),
1028            parent_index,
1029            score_components: snippet.score_components,
1030            signature_score: snippet.scores.signature,
1031            declaration_score: snippet.scores.declaration,
1032        });
1033    }
1034
1035    let excerpt_parent = index_state.and_then(|index_state| {
1036        context
1037            .excerpt
1038            .parent_declarations
1039            .last()
1040            .and_then(|(parent, _)| {
1041                add_signature(
1042                    *parent,
1043                    &mut declaration_to_signature_index,
1044                    &mut signatures,
1045                    index_state,
1046                )
1047            })
1048    });
1049
1050    predict_edits_v3::PredictEditsRequest {
1051        excerpt_path,
1052        excerpt: context.excerpt_text.body,
1053        excerpt_range: context.excerpt.range,
1054        cursor_offset: context.cursor_offset_in_excerpt,
1055        referenced_declarations,
1056        signatures,
1057        excerpt_parent,
1058        events,
1059        can_collect_data,
1060        diagnostic_groups,
1061        git_info,
1062        debug_info,
1063    }
1064}
1065
1066fn add_signature(
1067    declaration_id: DeclarationId,
1068    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1069    signatures: &mut Vec<Signature>,
1070    index: &SyntaxIndexState,
1071) -> Option<usize> {
1072    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1073        return Some(*signature_index);
1074    }
1075    let Some(parent_declaration) = index.declaration(declaration_id) else {
1076        log::error!("bug: missing parent declaration");
1077        return None;
1078    };
1079    let parent_index = parent_declaration.parent().and_then(|parent| {
1080        add_signature(parent, declaration_to_signature_index, signatures, index)
1081    });
1082    let (text, text_is_truncated) = parent_declaration.signature_text();
1083    let signature_index = signatures.len();
1084    signatures.push(Signature {
1085        text: text.into(),
1086        text_is_truncated,
1087        parent_index,
1088        range: parent_declaration.signature_range(),
1089    });
1090    declaration_to_signature_index.insert(declaration_id, signature_index);
1091    Some(signature_index)
1092}
1093
1094fn interpolate(
1095    old_snapshot: &BufferSnapshot,
1096    new_snapshot: &BufferSnapshot,
1097    current_edits: Arc<[(Range<Anchor>, String)]>,
1098) -> Option<Vec<(Range<Anchor>, String)>> {
1099    let mut edits = Vec::new();
1100
1101    let mut model_edits = current_edits.iter().peekable();
1102    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1103        while let Some((model_old_range, _)) = model_edits.peek() {
1104            let model_old_range = model_old_range.to_offset(old_snapshot);
1105            if model_old_range.end < user_edit.old.start {
1106                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1107                edits.push((model_old_range.clone(), model_new_text.clone()));
1108            } else {
1109                break;
1110            }
1111        }
1112
1113        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1114            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1115            if user_edit.old == model_old_offset_range {
1116                let user_new_text = new_snapshot
1117                    .text_for_range(user_edit.new.clone())
1118                    .collect::<String>();
1119
1120                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1121                    if !model_suffix.is_empty() {
1122                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1123                        edits.push((anchor..anchor, model_suffix.to_string()));
1124                    }
1125
1126                    model_edits.next();
1127                    continue;
1128                }
1129            }
1130        }
1131
1132        return None;
1133    }
1134
1135    edits.extend(model_edits.cloned());
1136
1137    if edits.is_empty() { None } else { Some(edits) }
1138}
1139
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;
    use language::ToOffset as _;

    /// Walks a prediction through a series of user edits (and one undo) and checks
    /// the interpolated edits after each step, ending with an edit that invalidates
    /// the prediction.
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            let predicted = [(2..5, "REM".to_string()), (9..11, "".to_string())];
            to_prediction_edits(predicted, &buffer, cx).into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back as-is.
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(9..11, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_edits(&prediction, &buffer, cx),
                vec![(4..4, "M".to_string())]
            );

            // Deleting text the prediction relied on invalidates it entirely.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current snapshot and returns
    /// the result as offset ranges. Panics if the prediction no longer interpolates.
    fn interpolated_edits(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let snapshot = buffer.read(cx).snapshot();
        let edits = prediction.interpolate(&snapshot).unwrap();
        from_prediction_edits(&edits, buffer, cx)
    }

    /// Converts offset-based edits into anchor-based edits for `buffer`: the start
    /// of each range uses `anchor_after`, the end uses `anchor_before`.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset-based edits, resolving the
    /// anchors against the current contents of `buffer`.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::with_capacity(editor_edits.len());
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}