zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  10use edit_prediction_context::{
  11    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  12    SyntaxIndexState,
  13};
  14use futures::AsyncReadExt as _;
  15use futures::channel::mpsc;
  16use gpui::http_client::Method;
  17use gpui::{
  18    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  19    http_client, prelude::*,
  20};
  21use language::{
  22    Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
  23};
  24use language::{BufferSnapshot, EditPreview};
  25use language_model::{LlmApiToken, RefreshLlmTokenListener};
  26use project::Project;
  27use release_channel::AppVersion;
  28use std::cmp;
  29use std::collections::{HashMap, VecDeque, hash_map};
  30use std::path::PathBuf;
  31use std::str::FromStr as _;
  32use std::time::{Duration, Instant};
  33use std::{ops::Range, sync::Arc};
  34use thiserror::Error;
  35use util::{ResultExt as _, some_or_debug_panic};
  36use uuid::Uuid;
  37use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  38
    /// Maximum time between two consecutive edits to the same buffer for them to be
    /// merged into a single `Event::BufferChange` (see `Zeta::push_event`).
  39const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  40
  41/// Maximum number of events to track.
    ///
    /// When the queue already holds this many events, `Zeta::push_event` drops the
    /// oldest half before appending the next one.
  42const MAX_EVENT_COUNT: usize = 16;
  43
    /// Excerpt-selection settings used by [`DEFAULT_OPTIONS`].
    // NOTE(review): the field meanings are defined by `EditPredictionExcerptOptions` in
    // `edit_prediction_context`; going by the names these are byte budgets plus a
    // before-cursor ratio — confirm there.
  44pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  45    max_bytes: 512,
  46    min_bytes: 128,
  47    target_before_cursor_over_total_bytes: 0.5,
  48};
  49
    /// Options every [`Zeta`] starts with (see `Zeta::new`).
  50pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  51    excerpt: DEFAULT_EXCERPT_OPTIONS,
  52    max_diagnostic_bytes: 2048,
  53};
  54
    /// App-wide handle to the shared [`Zeta`] entity; read and installed by `Zeta::global`.
  55#[derive(Clone)]
  56struct ZetaGlobal(Entity<Zeta>);
  57
    // Marker impl that lets `ZetaGlobal` be used with `cx.set_global` / `cx.try_global`.
  58impl Global for ZetaGlobal {}
  59
    /// Edit-prediction service state: credentials, per-project tracking and request options.
  60pub struct Zeta {
        /// Client used to acquire/refresh the LLM token and to obtain the HTTP client.
  61    client: Arc<Client>,
        /// Source of the reported edit-prediction usage; updated from response headers.
  62    user_store: Entity<UserStore>,
        /// Token sent as the `Authorization: Bearer` credential on prediction requests.
  63    llm_token: LlmApiToken,
        /// Keeps the `RefreshLlmTokenListener` subscription alive; its callback refreshes
        /// `llm_token`.
  64    _llm_token_subscription: Subscription,
        /// Per-project state, keyed by the project's entity id.
  65    projects: HashMap<EntityId, ZetaProject>,
        /// Options cloned into each request when it starts.
  66    options: ZetaOptions,
        /// Set once a request fails with `ZedUpdateRequiredError`.
  67    update_required: bool,
        /// Sender installed by `debug_info`; each request reports its debug info (or the
        /// stringified error) through it.
  68    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
  69}
  70
    /// Tunable settings applied to every prediction request.
  71#[derive(Debug, Clone, PartialEq)]
  72pub struct ZetaOptions {
        /// Excerpt options handed to `EditPredictionContext::gather_context`.
  73    pub excerpt: EditPredictionExcerptOptions,
        /// Byte budget for the serialized diagnostic groups attached to a request
        /// (enforced in `Zeta::gather_nearby_diagnostics`).
  74    pub max_diagnostic_bytes: usize,
  75}
  76
    /// Debug record sent over the channel returned by `Zeta::debug_info` for each
    /// successful prediction request.
  77pub struct PredictionDebugInfo {
        /// Context gathered around the cursor for this request.
  78    pub context: EditPredictionContext,
        /// Time spent between starting context gathering and finishing the request build.
  79    pub retrieval_time: TimeDelta,
        /// Debug info taken from the server response.
  80    pub request: RequestDebugInfo,
        /// Buffer the prediction was requested for.
  81    pub buffer: WeakEntity<Buffer>,
        /// Cursor position the prediction was requested at.
  82    pub position: language::Anchor,
  83}
  84
    /// Debug payload type carried in `PredictionDebugInfo::request`.
  85pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  86
    /// State tracked per registered project.
  87struct ZetaProject {
        /// Syntax index created for the project; its state is read when building requests.
  88    syntax_index: Entity<SyntaxIndex>,
        /// Recent buffer-change events, oldest first (bounded by `MAX_EVENT_COUNT`).
  89    events: VecDeque<Event>,
        /// Buffers being watched for edits, keyed by buffer entity id.
  90    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  91}
  92
    /// Tracking state for one watched buffer.
  93struct RegisteredBuffer {
        /// Last snapshot seen for the buffer; becomes `old_snapshot` of the next change event.
  94    snapshot: BufferSnapshot,
        /// Edit subscription and release observer (see `Zeta::register_buffer_impl`); held
        /// only to keep them alive.
  95    _subscriptions: [gpui::Subscription; 2],
  96}
  97
    /// A recorded user activity that is sent along with prediction requests.
  98#[derive(Clone)]
  99pub enum Event {
        /// A buffer went from `old_snapshot` to `new_snapshot`.
 100    BufferChange {
            /// Buffer contents before the change.
 101        old_snapshot: BufferSnapshot,
            /// Buffer contents after the change.
 102        new_snapshot: BufferSnapshot,
            /// When the change was recorded; used to coalesce rapid successive edits.
 103        timestamp: Instant,
 104    },
 105}
 106
 107impl Zeta {
 108    pub fn global(
 109        client: &Arc<Client>,
 110        user_store: &Entity<UserStore>,
 111        cx: &mut App,
 112    ) -> Entity<Self> {
 113        cx.try_global::<ZetaGlobal>()
 114            .map(|global| global.0.clone())
 115            .unwrap_or_else(|| {
 116                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 117                cx.set_global(ZetaGlobal(zeta.clone()));
 118                zeta
 119            })
 120    }
 121
 122    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 123        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 124
 125        Self {
 126            projects: HashMap::new(),
 127            client,
 128            user_store,
 129            options: DEFAULT_OPTIONS,
 130            llm_token: LlmApiToken::default(),
 131            _llm_token_subscription: cx.subscribe(
 132                &refresh_llm_token_listener,
 133                |this, _listener, _event, cx| {
 134                    let client = this.client.clone();
 135                    let llm_token = this.llm_token.clone();
 136                    cx.spawn(async move |_this, _cx| {
 137                        llm_token.refresh(&client).await?;
 138                        anyhow::Ok(())
 139                    })
 140                    .detach_and_log_err(cx);
 141                },
 142            ),
 143            update_required: false,
 144            debug_tx: None,
 145        }
 146    }
 147
 148    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 149        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 150        self.debug_tx = Some(debug_watch_tx);
 151        debug_watch_rx
 152    }
 153
        /// Returns the options currently applied to prediction requests.
 154    pub fn options(&self) -> &ZetaOptions {
 155        &self.options
 156    }
 157
        /// Replaces the options; requests clone them when they start.
 158    pub fn set_options(&mut self, options: ZetaOptions) {
 159        self.options = options;
 160    }
 161
        /// Returns the edit-prediction usage currently held by the user store, if any.
 162    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 163        self.user_store.read(cx).edit_prediction_usage()
 164    }
 165
        /// Ensures per-project state exists for `project`, creating it on the first call.
 166    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 167        self.get_or_init_zeta_project(project, cx);
 168    }
 169
        /// Starts tracking edits to `buffer` under `project`, creating the project state
        /// if needed. Registering an already-tracked buffer leaves its entry untouched.
 170    pub fn register_buffer(
 171        &mut self,
 172        buffer: &Entity<Buffer>,
 173        project: &Entity<Project>,
 174        cx: &mut Context<Self>,
 175    ) {
 176        let zeta_project = self.get_or_init_zeta_project(project, cx);
            // The returned `RegisteredBuffer` reference is not needed here.
 177        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 178    }
 179
 180    fn get_or_init_zeta_project(
 181        &mut self,
 182        project: &Entity<Project>,
 183        cx: &mut App,
 184    ) -> &mut ZetaProject {
 185        self.projects
 186            .entry(project.entity_id())
 187            .or_insert_with(|| ZetaProject {
 188                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 189                events: VecDeque::new(),
 190                registered_buffers: HashMap::new(),
 191            })
 192    }
 193
 194    fn register_buffer_impl<'a>(
 195        zeta_project: &'a mut ZetaProject,
 196        buffer: &Entity<Buffer>,
 197        project: &Entity<Project>,
 198        cx: &mut Context<Self>,
 199    ) -> &'a mut RegisteredBuffer {
 200        let buffer_id = buffer.entity_id();
 201        match zeta_project.registered_buffers.entry(buffer_id) {
 202            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 203            hash_map::Entry::Vacant(entry) => {
 204                let snapshot = buffer.read(cx).snapshot();
 205                let project_entity_id = project.entity_id();
 206                entry.insert(RegisteredBuffer {
 207                    snapshot,
 208                    _subscriptions: [
 209                        cx.subscribe(buffer, {
 210                            let project = project.downgrade();
 211                            move |this, buffer, event, cx| {
 212                                if let language::BufferEvent::Edited = event
 213                                    && let Some(project) = project.upgrade()
 214                                {
 215                                    this.report_changes_for_buffer(&buffer, &project, cx);
 216                                }
 217                            }
 218                        }),
 219                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 220                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 221                            else {
 222                                return;
 223                            };
 224                            zeta_project.registered_buffers.remove(&buffer_id);
 225                        }),
 226                    ],
 227                })
 228            }
 229        }
 230    }
 231
 232    fn report_changes_for_buffer(
 233        &mut self,
 234        buffer: &Entity<Buffer>,
 235        project: &Entity<Project>,
 236        cx: &mut Context<Self>,
 237    ) -> BufferSnapshot {
 238        let zeta_project = self.get_or_init_zeta_project(project, cx);
 239        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 240
 241        let new_snapshot = buffer.read(cx).snapshot();
 242        if new_snapshot.version != registered_buffer.snapshot.version {
 243            let old_snapshot =
 244                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 245            Self::push_event(
 246                zeta_project,
 247                Event::BufferChange {
 248                    old_snapshot,
 249                    new_snapshot: new_snapshot.clone(),
 250                    timestamp: Instant::now(),
 251                },
 252            );
 253        }
 254
 255        new_snapshot
 256    }
 257
 258    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 259        let events = &mut zeta_project.events;
 260
 261        if let Some(Event::BufferChange {
 262            new_snapshot: last_new_snapshot,
 263            timestamp: last_timestamp,
 264            ..
 265        }) = events.back_mut()
 266        {
 267            // Coalesce edits for the same buffer when they happen one after the other.
 268            let Event::BufferChange {
 269                old_snapshot,
 270                new_snapshot,
 271                timestamp,
 272            } = &event;
 273
 274            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 275                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 276                && old_snapshot.version == last_new_snapshot.version
 277            {
 278                *last_new_snapshot = new_snapshot.clone();
 279                *last_timestamp = *timestamp;
 280                return;
 281            }
 282        }
 283
 284        if events.len() >= MAX_EVENT_COUNT {
 285            // These are halved instead of popping to improve prompt caching.
 286            events.drain(..MAX_EVENT_COUNT / 2);
 287        }
 288
 289        events.push_back(event);
 290    }
 291
        /// Requests an edit prediction for `buffer` at `position`.
        ///
        /// Everything the request needs is captured up front from `cx` (index state handle,
        /// options, buffer and worktree snapshots, recorded events). Context gathering and the
        /// HTTP call then run on the background executor, and the response is turned into an
        /// `EditPrediction` back on the foreground task.
        ///
        /// Resolves to `Ok(None)` when no context could be gathered or when the predicted
        /// edits can no longer be interpolated onto the buffer's current contents.
        ///
        /// # Errors
        ///
        /// Fails immediately if the buffer has no file, and otherwise with whatever
        /// `perform_request` returns. A `ZedUpdateRequiredError` additionally sets
        /// `update_required` and shows an "Update Zed" notification.
 292    pub fn request_prediction(
 293        &mut self,
 294        project: &Entity<Project>,
 295        buffer: &Entity<Buffer>,
 296        position: language::Anchor,
 297        cx: &mut Context<Self>,
 298    ) -> Task<Result<Option<EditPrediction>>> {
            // `None` when the project was never registered; the request is then built
            // without an index and without events.
 299        let project_state = self.projects.get(&project.entity_id());
 300
 301        let index_state = project_state.map(|state| {
 302            state
 303                .syntax_index
 304                .read_with(cx, |index, _cx| index.state().clone())
 305        });
 306        let options = self.options.clone();
 307        let snapshot = buffer.read(cx).snapshot();
 308        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 309            return Task::ready(Err(anyhow!("No file path for excerpt")));
 310        };
 311        let client = self.client.clone();
 312        let llm_token = self.llm_token.clone();
 313        let app_version = AppVersion::global(cx);
 314        let worktree_snapshots = project
 315            .read(cx)
 316            .worktrees(cx)
 317            .map(|worktree| worktree.read(cx).snapshot())
 318            .collect::<Vec<_>>();
 319        let debug_tx = self.debug_tx.clone();
 320
            // Convert the recorded buffer changes into request events, each carrying a
            // unified diff between its old and new snapshot text.
 321        let events = project_state
 322            .map(|state| {
 323                state
 324                    .events
 325                    .iter()
 326                    .map(|event| match event {
 327                        Event::BufferChange {
 328                            old_snapshot,
 329                            new_snapshot,
 330                            ..
 331                        } => {
 332                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
 333
                                // `old_path` is only reported when it differs from `path`.
 334                            let old_path = old_snapshot.file().and_then(|f| {
 335                                let old_path = f.path().as_ref();
 336                                if Some(old_path) != path.as_deref() {
 337                                    Some(old_path.to_path_buf())
 338                                } else {
 339                                    None
 340                                }
 341                            });
 342
 343                            predict_edits_v3::Event::BufferChange {
 344                                old_path,
 345                                path,
 346                                diff: language::unified_diff(
 347                                    &old_snapshot.text(),
 348                                    &new_snapshot.text(),
 349                                ),
 350                                //todo: Actually detect if this edit was predicted or not
 351                                predicted: false,
 352                            }
 353                        }
 354                    })
 355                    .collect::<Vec<_>>()
 356            })
 357            .unwrap_or_default();
 358
 359        let diagnostics = snapshot.diagnostic_sets().clone();
 360
            // Background part: gather context, build the request and send it.
 361        let request_task = cx.background_spawn({
 362            let snapshot = snapshot.clone();
 363            let buffer = buffer.clone();
 364            async move {
                    // Wait for the index lock when the project has an index.
 365                let index_state = if let Some(index_state) = index_state {
 366                    Some(index_state.lock_owned().await)
 367                } else {
 368                    None
 369                };
 370
 371                let cursor_offset = position.to_offset(&snapshot);
 372                let cursor_point = cursor_offset.to_point(&snapshot);
 373
 374                let before_retrieval = chrono::Utc::now();
 375
                    // No context means no request and no prediction.
 376                let Some(context) = EditPredictionContext::gather_context(
 377                    cursor_point,
 378                    &snapshot,
 379                    &options.excerpt,
 380                    index_state.as_deref(),
 381                ) else {
 382                    return Ok(None);
 383                };
 384
                    // The context is only copied when a debug channel is installed.
 385                let debug_context = if let Some(debug_tx) = debug_tx {
 386                    Some((debug_tx, context.clone()))
 387                } else {
 388                    None
 389                };
 390
 391                let (diagnostic_groups, diagnostic_groups_truncated) =
 392                    Self::gather_nearby_diagnostics(
 393                        cursor_offset,
 394                        &diagnostics,
 395                        &snapshot,
 396                        options.max_diagnostic_bytes,
 397                    );
 398
 399                let request = make_cloud_request(
 400                    excerpt_path.clone(),
 401                    context,
 402                    events,
 403                    // TODO data collection
 404                    false,
 405                    diagnostic_groups,
 406                    diagnostic_groups_truncated,
 407                    None,
 408                    debug_context.is_some(),
 409                    &worktree_snapshots,
 410                    index_state.as_deref(),
 411                );
 412
                    // Covers context gathering, diagnostics and request construction,
                    // but not the network round trip below.
 413                let retrieval_time = chrono::Utc::now() - before_retrieval;
 414                let response = Self::perform_request(client, llm_token, app_version, request).await;
 415
                    // Report either the response's debug info or the stringified error;
                    // a closed receiver is ignored.
 416                if let Some((debug_tx, context)) = debug_context {
 417                    debug_tx
 418                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 419                            |response| {
 420                                let Some(request) =
 421                                    some_or_debug_panic(response.0.debug_info.clone())
 422                                else {
 423                                    return Err("Missing debug info".to_string());
 424                                };
 425                                Ok(PredictionDebugInfo {
 426                                    context,
 427                                    request,
 428                                    retrieval_time,
 429                                    buffer: buffer.downgrade(),
 430                                    position,
 431                                })
 432                            },
 433                        ))
 434                        .ok();
 435                }
 436
                    // Request failures are propagated only after the debug channel was notified.
 437                anyhow::Ok(Some(response?))
 438            }
 439        });
 440
 441        let buffer = buffer.clone();
 442
            // Foreground part: record usage and turn the response into an `EditPrediction`.
 443        cx.spawn(async move |this, cx| {
 444            match request_task.await {
 445                Ok(Some((response, usage))) => {
 446                    log::debug!("predicted edits: {:?}", &response.edits);
 447
 448                    if let Some(usage) = usage {
 449                        this.update(cx, |this, cx| {
 450                            this.user_store.update(cx, |user_store, cx| {
 451                                user_store.update_edit_prediction_usage(usage, cx);
 452                            });
 453                        })
 454                        .ok();
 455                    }
 456
 457                    // TODO telemetry: duration, etc
 458
 459                    // TODO produce smaller edits by diffing against snapshot first
 460                    //
 461                    // Cloud returns entire snippets/excerpts ranges as they were included
 462                    // in the request, but we should display smaller edits to the user.
 463                    //
 464                    // We can do this by computing a diff of each one against the snapshot.
 465                    // Similar to zeta::Zeta::compute_edits, but per edit.
                        // Ranges are anchored in the snapshot the request was built from.
 466                    let edits = response
 467                        .edits
 468                        .into_iter()
 469                        .map(|edit| {
 470                            // TODO edits to different files
 471                            (
 472                                snapshot.anchor_before(edit.range.start)
 473                                    ..snapshot.anchor_before(edit.range.end),
 474                                edit.content,
 475                            )
 476                        })
 477                        .collect::<Vec<_>>()
 478                        .into();
 479
                        // Carry the edits over to the buffer's current contents; if
                        // `interpolate` yields nothing, there is no prediction to show.
 480                    let Some((edits, snapshot, edit_preview_task)) =
 481                        buffer.read_with(cx, |buffer, cx| {
 482                            let new_snapshot = buffer.snapshot();
 483                            let edits: Arc<[_]> =
 484                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 485                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 486                        })?
 487                    else {
 488                        return Ok(None);
 489                    };
 490
 491                    Ok(Some(EditPrediction {
 492                        id: EditPredictionId(response.request_id),
 493                        edits,
 494                        snapshot,
 495                        edit_preview: edit_preview_task.await,
 496                    }))
 497                }
 498                Ok(None) => Ok(None),
 499                Err(err) => {
                        // The server demanded a newer Zed: remember it and prompt the user.
 500                    if err.is::<ZedUpdateRequiredError>() {
 501                        cx.update(|cx| {
 502                            this.update(cx, |this, _cx| {
 503                                this.update_required = true;
 504                            })
 505                            .ok();
 506
 507                            let error_message: SharedString = err.to_string().into();
 508                            show_app_notification(
 509                                NotificationId::unique::<ZedUpdateRequiredError>(),
 510                                cx,
 511                                move |cx| {
 512                                    cx.new(|cx| {
 513                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 514                                            .with_link_button(
 515                                                "Update Zed",
 516                                                "https://zed.dev/releases",
 517                                            )
 518                                    })
 519                                },
 520                            );
 521                        })
 522                        .ok();
 523                    }
 524
                        // The error reaches the caller in every case.
 525                    Err(err)
 526                }
 527            }
 528        })
 529    }
 530
 531    async fn perform_request(
 532        client: Arc<Client>,
 533        llm_token: LlmApiToken,
 534        app_version: SemanticVersion,
 535        request: predict_edits_v3::PredictEditsRequest,
 536    ) -> Result<(
 537        predict_edits_v3::PredictEditsResponse,
 538        Option<EditPredictionUsage>,
 539    )> {
 540        let http_client = client.http_client();
 541        let mut token = llm_token.acquire(&client).await?;
 542        let mut did_retry = false;
 543
 544        loop {
 545            let request_builder = http_client::Request::builder().method(Method::POST);
 546            let request_builder =
 547                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 548                    request_builder.uri(predict_edits_url)
 549                } else {
 550                    request_builder.uri(
 551                        http_client
 552                            .build_zed_llm_url("/predict_edits/v3", &[])?
 553                            .as_ref(),
 554                    )
 555                };
 556            let request = request_builder
 557                .header("Content-Type", "application/json")
 558                .header("Authorization", format!("Bearer {}", token))
 559                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 560                .body(serde_json::to_string(&request)?.into())?;
 561
 562            let mut response = http_client.send(request).await?;
 563
 564            if let Some(minimum_required_version) = response
 565                .headers()
 566                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 567                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 568            {
 569                anyhow::ensure!(
 570                    app_version >= minimum_required_version,
 571                    ZedUpdateRequiredError {
 572                        minimum_version: minimum_required_version
 573                    }
 574                );
 575            }
 576
 577            if response.status().is_success() {
 578                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 579
 580                let mut body = Vec::new();
 581                response.body_mut().read_to_end(&mut body).await?;
 582                return Ok((serde_json::from_slice(&body)?, usage));
 583            } else if !did_retry
 584                && response
 585                    .headers()
 586                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 587                    .is_some()
 588            {
 589                did_retry = true;
 590                token = llm_token.refresh(&client).await?;
 591            } else {
 592                let mut body = String::new();
 593                response.body_mut().read_to_string(&mut body).await?;
 594                anyhow::bail!(
 595                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 596                    response.status(),
 597                    body
 598                );
 599            }
 600        }
 601    }
 602
 603    fn gather_nearby_diagnostics(
 604        cursor_offset: usize,
 605        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 606        snapshot: &BufferSnapshot,
 607        max_diagnostics_bytes: usize,
 608    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 609        // TODO: Could make this more efficient
 610        let mut diagnostic_groups = Vec::new();
 611        for (language_server_id, diagnostics) in diagnostic_sets {
 612            let mut groups = Vec::new();
 613            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 614            diagnostic_groups.extend(
 615                groups
 616                    .into_iter()
 617                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 618            );
 619        }
 620
 621        // sort by proximity to cursor
 622        diagnostic_groups.sort_by_key(|group| {
 623            let range = &group.entries[group.primary_ix].range;
 624            if range.start >= cursor_offset {
 625                range.start - cursor_offset
 626            } else if cursor_offset >= range.end {
 627                cursor_offset - range.end
 628            } else {
 629                (cursor_offset - range.start).min(range.end - cursor_offset)
 630            }
 631        });
 632
 633        let mut results = Vec::new();
 634        let mut diagnostic_groups_truncated = false;
 635        let mut diagnostics_byte_count = 0;
 636        for group in diagnostic_groups {
 637            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 638            diagnostics_byte_count += raw_value.get().len();
 639            if diagnostics_byte_count > max_diagnostics_bytes {
 640                diagnostic_groups_truncated = true;
 641                break;
 642            }
 643            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 644        }
 645
 646        (results, diagnostic_groups_truncated)
 647    }
 648
 649    // TODO: Dedupe with similar code in request_prediction?
 650    pub fn cloud_request_for_zeta_cli(
 651        &mut self,
 652        project: &Entity<Project>,
 653        buffer: &Entity<Buffer>,
 654        position: language::Anchor,
 655        cx: &mut Context<Self>,
 656    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 657        let project_state = self.projects.get(&project.entity_id());
 658
 659        let index_state = project_state.map(|state| {
 660            state
 661                .syntax_index
 662                .read_with(cx, |index, _cx| index.state().clone())
 663        });
 664        let options = self.options.clone();
 665        let snapshot = buffer.read(cx).snapshot();
 666        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 667            return Task::ready(Err(anyhow!("No file path for excerpt")));
 668        };
 669        let worktree_snapshots = project
 670            .read(cx)
 671            .worktrees(cx)
 672            .map(|worktree| worktree.read(cx).snapshot())
 673            .collect::<Vec<_>>();
 674
 675        cx.background_spawn(async move {
 676            let index_state = if let Some(index_state) = index_state {
 677                Some(index_state.lock_owned().await)
 678            } else {
 679                None
 680            };
 681
 682            let cursor_point = position.to_point(&snapshot);
 683
 684            let debug_info = true;
 685            EditPredictionContext::gather_context(
 686                cursor_point,
 687                &snapshot,
 688                &options.excerpt,
 689                index_state.as_deref(),
 690            )
 691            .context("Failed to select excerpt")
 692            .map(|context| {
 693                make_cloud_request(
 694                    excerpt_path.clone(),
 695                    context,
 696                    // TODO pass everything
 697                    Vec::new(),
 698                    false,
 699                    Vec::new(),
 700                    false,
 701                    None,
 702                    debug_info,
 703                    &worktree_snapshots,
 704                    index_state.as_deref(),
 705                )
 706            })
 707        })
 708    }
 709}
 710
    /// Error returned by `Zeta::perform_request` when the server's minimum required version
    /// is newer than the running app version.
 711#[derive(Error, Debug)]
 712#[error(
 713    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 714)]
 715pub struct ZedUpdateRequiredError {
        /// Minimum version advertised by the server.
 716    minimum_version: SemanticVersion,
 717}
 718
    /// `EditPredictionProvider` implementation backed by the shared [`Zeta`] entity.
 719pub struct ZetaEditPredictionProvider {
        /// Shared prediction service obtained from `Zeta::global`.
 720    zeta: Entity<Zeta>,
        /// Prediction currently held by the provider, if any.
 721    current_prediction: Option<CurrentEditPrediction>,
        /// Starts at 0.
        // NOTE(review): presumably the id handed to the next `PendingPrediction`; the code
        // that assigns it is not part of this excerpt — confirm.
 722    next_pending_prediction_id: usize,
        /// In-flight predictions (at most two); non-empty means `is_refreshing()` is true.
 723    pending_predictions: ArrayVec<PendingPrediction, 2>,
        /// Initialised to the construction time.
        // NOTE(review): presumably compared against `THROTTLE_TIMEOUT` when refreshing — confirm.
 724    last_request_timestamp: Instant,
 725}
 726
 727impl ZetaEditPredictionProvider {
 728    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 729
 730    pub fn new(
 731        project: Option<&Entity<Project>>,
 732        client: &Arc<Client>,
 733        user_store: &Entity<UserStore>,
 734        cx: &mut App,
 735    ) -> Self {
 736        let zeta = Zeta::global(client, user_store, cx);
 737        if let Some(project) = project {
 738            zeta.update(cx, |zeta, cx| {
 739                zeta.register_project(project, cx);
 740            });
 741        }
 742
 743        Self {
 744            zeta,
 745            current_prediction: None,
 746            next_pending_prediction_id: 0,
 747            pending_predictions: ArrayVec::new(),
 748            last_request_timestamp: Instant::now(),
 749        }
 750    }
 751}
 752
    /// A prediction together with the buffer it belongs to.
 753#[derive(Clone)]
 754struct CurrentEditPrediction {
        /// Entity id of the buffer the prediction applies to.
 755    buffer_id: EntityId,
        /// The prediction itself.
 756    prediction: EditPrediction,
 757}
 758
 759impl CurrentEditPrediction {
 760    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 761        if self.buffer_id != old_prediction.buffer_id {
 762            return true;
 763        }
 764
 765        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 766            return true;
 767        };
 768        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 769            return false;
 770        };
 771
 772        if old_edits.len() == 1 && new_edits.len() == 1 {
 773            let (old_range, old_text) = &old_edits[0];
 774            let (new_range, new_text) = &new_edits[0];
 775            new_range == old_range && new_text.starts_with(old_text)
 776        } else {
 777            true
 778        }
 779    }
 780}
 781
    /// Identifier of a prediction; built from the response's `request_id` in
    /// `Zeta::request_prediction`.
 782#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 783pub struct EditPredictionId(Uuid);
 784
    /// Lets a prediction id be used directly as a gpui element id.
 785impl From<EditPredictionId> for gpui::ElementId {
 786    fn from(value: EditPredictionId) -> Self {
 787        gpui::ElementId::Uuid(value.0)
 788    }
 789}
 790
    /// Formats the id as its underlying UUID.
 791impl std::fmt::Display for EditPredictionId {
 792    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 793        write!(f, "{}", self.0)
 794    }
 795}
 796
    /// A set of predicted edits for one buffer.
 797#[derive(Clone)]
 798pub struct EditPrediction {
        /// Id derived from the originating request.
 799    id: EditPredictionId,
        /// Anchored ranges and their replacement text.
 800    edits: Arc<[(Range<Anchor>, String)]>,
        /// Buffer snapshot the `edits` were computed against.
 801    snapshot: BufferSnapshot,
        /// Preview produced by `Buffer::preview_edits` for these edits.
 802    edit_preview: EditPreview,
 803}
 804
 805impl EditPrediction {
        /// Re-expresses the edits relative to `new_snapshot` by delegating to the free
        /// `interpolate` function; `None` means they no longer apply.
        // NOTE(review): `interpolate` is defined outside this excerpt — confirm the exact
        // conditions under which it returns `None`.
 806    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 807        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 808    }
 809}
 810
/// Bookkeeping entry for a prediction request that has not finished yet.
struct PendingPrediction {
    /// Sequence number handed out by the provider; the finishing task compares
    /// it with the id at the front of the pending list.
    id: usize,
    /// The spawned task performing the request. It is never read, only held.
    /// NOTE(review): presumably dropping a gpui `Task` cancels it, which is why
    /// clearing the pending list aborts in-flight requests — confirm.
    _task: Task<()>,
}
 815
 816impl EditPredictionProvider for ZetaEditPredictionProvider {
 817    fn name() -> &'static str {
 818        "zed-predict2"
 819    }
 820
 821    fn display_name() -> &'static str {
 822        "Zed's Edit Predictions 2"
 823    }
 824
 825    fn show_completions_in_menu() -> bool {
 826        true
 827    }
 828
 829    fn show_tab_accept_marker() -> bool {
 830        true
 831    }
 832
 833    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 834        // TODO [zeta2]
 835        DataCollectionState::Unsupported
 836    }
 837
 838    fn toggle_data_collection(&mut self, _cx: &mut App) {
 839        // TODO [zeta2]
 840    }
 841
 842    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 843        self.zeta.read(cx).usage(cx)
 844    }
 845
 846    fn is_enabled(
 847        &self,
 848        _buffer: &Entity<language::Buffer>,
 849        _cursor_position: language::Anchor,
 850        _cx: &App,
 851    ) -> bool {
 852        true
 853    }
 854
 855    fn is_refreshing(&self) -> bool {
 856        !self.pending_predictions.is_empty()
 857    }
 858
 859    fn refresh(
 860        &mut self,
 861        project: Option<Entity<project::Project>>,
 862        buffer: Entity<language::Buffer>,
 863        cursor_position: language::Anchor,
 864        _debounce: bool,
 865        cx: &mut Context<Self>,
 866    ) {
 867        let Some(project) = project else {
 868            return;
 869        };
 870
 871        if self
 872            .zeta
 873            .read(cx)
 874            .user_store
 875            .read_with(cx, |user_store, _cx| {
 876                user_store.account_too_young() || user_store.has_overdue_invoices()
 877            })
 878        {
 879            return;
 880        }
 881
 882        if let Some(current_prediction) = self.current_prediction.as_ref() {
 883            let snapshot = buffer.read(cx).snapshot();
 884            if current_prediction
 885                .prediction
 886                .interpolate(&snapshot)
 887                .is_some()
 888            {
 889                return;
 890            }
 891        }
 892
 893        let pending_prediction_id = self.next_pending_prediction_id;
 894        self.next_pending_prediction_id += 1;
 895        let last_request_timestamp = self.last_request_timestamp;
 896
 897        let task = cx.spawn(async move |this, cx| {
 898            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 899                .checked_duration_since(Instant::now())
 900            {
 901                cx.background_executor().timer(timeout).await;
 902            }
 903
 904            let prediction_request = this.update(cx, |this, cx| {
 905                this.last_request_timestamp = Instant::now();
 906                this.zeta.update(cx, |zeta, cx| {
 907                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 908                })
 909            });
 910
 911            let prediction = match prediction_request {
 912                Ok(prediction_request) => {
 913                    let prediction_request = prediction_request.await;
 914                    prediction_request.map(|c| {
 915                        c.map(|prediction| CurrentEditPrediction {
 916                            buffer_id: buffer.entity_id(),
 917                            prediction,
 918                        })
 919                    })
 920                }
 921                Err(error) => Err(error),
 922            };
 923
 924            this.update(cx, |this, cx| {
 925                if this.pending_predictions[0].id == pending_prediction_id {
 926                    this.pending_predictions.remove(0);
 927                } else {
 928                    this.pending_predictions.clear();
 929                }
 930
 931                let Some(new_prediction) = prediction
 932                    .context("edit prediction failed")
 933                    .log_err()
 934                    .flatten()
 935                else {
 936                    cx.notify();
 937                    return;
 938                };
 939
 940                if let Some(old_prediction) = this.current_prediction.as_ref() {
 941                    let snapshot = buffer.read(cx).snapshot();
 942                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 943                        this.current_prediction = Some(new_prediction);
 944                    }
 945                } else {
 946                    this.current_prediction = Some(new_prediction);
 947                }
 948
 949                cx.notify();
 950            })
 951            .ok();
 952        });
 953
 954        // We always maintain at most two pending predictions. When we already
 955        // have two, we replace the newest one.
 956        if self.pending_predictions.len() <= 1 {
 957            self.pending_predictions.push(PendingPrediction {
 958                id: pending_prediction_id,
 959                _task: task,
 960            });
 961        } else if self.pending_predictions.len() == 2 {
 962            self.pending_predictions.pop();
 963            self.pending_predictions.push(PendingPrediction {
 964                id: pending_prediction_id,
 965                _task: task,
 966            });
 967        }
 968
 969        cx.notify();
 970    }
 971
 972    fn cycle(
 973        &mut self,
 974        _buffer: Entity<language::Buffer>,
 975        _cursor_position: language::Anchor,
 976        _direction: Direction,
 977        _cx: &mut Context<Self>,
 978    ) {
 979    }
 980
 981    fn accept(&mut self, _cx: &mut Context<Self>) {
 982        // TODO [zeta2] report accept
 983        self.current_prediction.take();
 984        self.pending_predictions.clear();
 985    }
 986
 987    fn discard(&mut self, _cx: &mut Context<Self>) {
 988        self.pending_predictions.clear();
 989        self.current_prediction.take();
 990    }
 991
 992    fn suggest(
 993        &mut self,
 994        buffer: &Entity<language::Buffer>,
 995        cursor_position: language::Anchor,
 996        cx: &mut Context<Self>,
 997    ) -> Option<edit_prediction::EditPrediction> {
 998        let CurrentEditPrediction {
 999            buffer_id,
1000            prediction,
1001            ..
1002        } = self.current_prediction.as_mut()?;
1003
1004        // Invalidate previous prediction if it was generated for a different buffer.
1005        if *buffer_id != buffer.entity_id() {
1006            self.current_prediction.take();
1007            return None;
1008        }
1009
1010        let buffer = buffer.read(cx);
1011        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1012            self.current_prediction.take();
1013            return None;
1014        };
1015
1016        let cursor_row = cursor_position.to_point(buffer).row;
1017        let (closest_edit_ix, (closest_edit_range, _)) =
1018            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1019                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1020                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1021                cmp::min(distance_from_start, distance_from_end)
1022            })?;
1023
1024        let mut edit_start_ix = closest_edit_ix;
1025        for (range, _) in edits[..edit_start_ix].iter().rev() {
1026            let distance_from_closest_edit =
1027                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1028            if distance_from_closest_edit <= 1 {
1029                edit_start_ix -= 1;
1030            } else {
1031                break;
1032            }
1033        }
1034
1035        let mut edit_end_ix = closest_edit_ix + 1;
1036        for (range, _) in &edits[edit_end_ix..] {
1037            let distance_from_closest_edit =
1038                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1039            if distance_from_closest_edit <= 1 {
1040                edit_end_ix += 1;
1041            } else {
1042                break;
1043            }
1044        }
1045
1046        Some(edit_prediction::EditPrediction {
1047            id: Some(prediction.id.to_string().into()),
1048            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1049            edit_preview: Some(prediction.edit_preview.clone()),
1050        })
1051    }
1052}
1053
1054fn make_cloud_request(
1055    excerpt_path: PathBuf,
1056    context: EditPredictionContext,
1057    events: Vec<predict_edits_v3::Event>,
1058    can_collect_data: bool,
1059    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1060    diagnostic_groups_truncated: bool,
1061    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1062    debug_info: bool,
1063    worktrees: &Vec<worktree::Snapshot>,
1064    index_state: Option<&SyntaxIndexState>,
1065) -> predict_edits_v3::PredictEditsRequest {
1066    let mut signatures = Vec::new();
1067    let mut declaration_to_signature_index = HashMap::default();
1068    let mut referenced_declarations = Vec::new();
1069
1070    for snippet in context.snippets {
1071        let project_entry_id = snippet.declaration.project_entry_id();
1072        let Some(path) = worktrees.iter().find_map(|worktree| {
1073            worktree.entry_for_id(project_entry_id).map(|entry| {
1074                let mut full_path = PathBuf::new();
1075                full_path.push(worktree.root_name());
1076                full_path.push(&entry.path);
1077                full_path
1078            })
1079        }) else {
1080            continue;
1081        };
1082
1083        let parent_index = index_state.and_then(|index_state| {
1084            snippet.declaration.parent().and_then(|parent| {
1085                add_signature(
1086                    parent,
1087                    &mut declaration_to_signature_index,
1088                    &mut signatures,
1089                    index_state,
1090                )
1091            })
1092        });
1093
1094        let (text, text_is_truncated) = snippet.declaration.item_text();
1095        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1096            path,
1097            text: text.into(),
1098            range: snippet.declaration.item_range(),
1099            text_is_truncated,
1100            signature_range: snippet.declaration.signature_range_in_item_text(),
1101            parent_index,
1102            score_components: snippet.score_components,
1103            signature_score: snippet.scores.signature,
1104            declaration_score: snippet.scores.declaration,
1105        });
1106    }
1107
1108    let excerpt_parent = index_state.and_then(|index_state| {
1109        context
1110            .excerpt
1111            .parent_declarations
1112            .last()
1113            .and_then(|(parent, _)| {
1114                add_signature(
1115                    *parent,
1116                    &mut declaration_to_signature_index,
1117                    &mut signatures,
1118                    index_state,
1119                )
1120            })
1121    });
1122
1123    predict_edits_v3::PredictEditsRequest {
1124        excerpt_path,
1125        excerpt: context.excerpt_text.body,
1126        excerpt_range: context.excerpt.range,
1127        cursor_offset: context.cursor_offset_in_excerpt,
1128        referenced_declarations,
1129        signatures,
1130        excerpt_parent,
1131        events,
1132        can_collect_data,
1133        diagnostic_groups,
1134        diagnostic_groups_truncated,
1135
1136        git_info,
1137        debug_info,
1138    }
1139}
1140
1141fn add_signature(
1142    declaration_id: DeclarationId,
1143    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1144    signatures: &mut Vec<Signature>,
1145    index: &SyntaxIndexState,
1146) -> Option<usize> {
1147    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1148        return Some(*signature_index);
1149    }
1150    let Some(parent_declaration) = index.declaration(declaration_id) else {
1151        log::error!("bug: missing parent declaration");
1152        return None;
1153    };
1154    let parent_index = parent_declaration.parent().and_then(|parent| {
1155        add_signature(parent, declaration_to_signature_index, signatures, index)
1156    });
1157    let (text, text_is_truncated) = parent_declaration.signature_text();
1158    let signature_index = signatures.len();
1159    signatures.push(Signature {
1160        text: text.into(),
1161        text_is_truncated,
1162        parent_index,
1163        range: parent_declaration.signature_range(),
1164    });
1165    declaration_to_signature_index.insert(declaration_id, signature_index);
1166    Some(signature_index)
1167}
1168
1169fn interpolate(
1170    old_snapshot: &BufferSnapshot,
1171    new_snapshot: &BufferSnapshot,
1172    current_edits: Arc<[(Range<Anchor>, String)]>,
1173) -> Option<Vec<(Range<Anchor>, String)>> {
1174    let mut edits = Vec::new();
1175
1176    let mut model_edits = current_edits.iter().peekable();
1177    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1178        while let Some((model_old_range, _)) = model_edits.peek() {
1179            let model_old_range = model_old_range.to_offset(old_snapshot);
1180            if model_old_range.end < user_edit.old.start {
1181                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1182                edits.push((model_old_range.clone(), model_new_text.clone()));
1183            } else {
1184                break;
1185            }
1186        }
1187
1188        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1189            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1190            if user_edit.old == model_old_offset_range {
1191                let user_new_text = new_snapshot
1192                    .text_for_range(user_edit.new.clone())
1193                    .collect::<String>();
1194
1195                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1196                    if !model_suffix.is_empty() {
1197                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1198                        edits.push((anchor..anchor, model_suffix.to_string()));
1199                    }
1200
1201                    model_edits.next();
1202                    continue;
1203                }
1204            }
1205        }
1206
1207        return None;
1208    }
1209
1210    edits.extend(model_edits.cloned());
1211
1212    if edits.is_empty() { None } else { Some(edits) }
1213}
1214
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Walks a prediction through a series of user edits and checks, after
    /// each one, which predicted edits remain (expressed as buffer offsets).
    #[gpui::test]
    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
            to_prediction_edits(
                [(2..5, "REM".to_string()), (9..11, "".to_string())],
                &buffer,
                cx,
            )
            .into()
        });

        let edit_preview = cx
            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
            .await;

        let prediction = EditPrediction {
            id: EditPredictionId(Uuid::new_v4()),
            edits,
            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
            edit_preview,
        };

        cx.update(|cx| {
            // Untouched buffer: the prediction comes back unchanged.
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // The user deletes exactly the first predicted range.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..2, "REM".to_string()), (6..8, "".to_string())])
            );

            // Undoing restores the original prediction.
            buffer.update(cx, |buffer, cx| buffer.undo(cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(2..5, "REM".to_string()), (9..11, "".to_string())])
            );

            // The user types the predicted text one character at a time.
            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(3..3, "EM".to_string()), (7..9, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(9..11, "".to_string())])
            );

            // Removing the last typed character brings its prediction back.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string()), (8..10, "".to_string())])
            );

            // The user performs the second predicted deletion themselves.
            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
            assert_eq!(
                interpolated_offsets(&prediction, &buffer, cx),
                Some(vec![(4..4, "M".to_string())])
            );

            // An edit that does not match the prediction invalidates it.
            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
            assert_eq!(interpolated_offsets(&prediction, &buffer, cx), None);
        })
    }

    /// Interpolates `prediction` against the buffer's current contents and
    /// converts the surviving edits to offset ranges.
    fn interpolated_offsets(
        prediction: &EditPrediction,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Option<Vec<(Range<usize>, String)>> {
        let snapshot = buffer.read(cx).snapshot();
        prediction
            .interpolate(&snapshot)
            .map(|edits| from_prediction_edits(&edits, buffer, cx))
    }

    /// Converts offset-based edits into anchor-based ones: the start is
    /// anchored after its offset, the end before its offset.
    fn to_prediction_edits(
        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<Anchor>, String)> {
        let buffer = buffer.read(cx);
        let mut anchored = Vec::new();
        for (range, text) in iterator {
            let start = buffer.anchor_after(range.start);
            let end = buffer.anchor_before(range.end);
            anchored.push((start..end, text));
        }
        anchored
    }

    /// Converts anchor-based edits back into offset ranges in the buffer's
    /// current state.
    fn from_prediction_edits(
        editor_edits: &[(Range<Anchor>, String)],
        buffer: &Entity<Buffer>,
        cx: &App,
    ) -> Vec<(Range<usize>, String)> {
        let buffer = buffer.read(cx);
        let mut resolved = Vec::new();
        for (range, text) in editor_edits {
            let start = range.start.to_offset(buffer);
            let end = range.end.to_offset(buffer);
            resolved.push((start..end, text.clone()));
        }
        resolved
    }
}