zeta2.rs

   1use anyhow::{Context as _, Result, anyhow};
   2use arrayvec::ArrayVec;
   3use chrono::TimeDelta;
   4use client::{Client, EditPredictionUsage, UserStore};
   5use cloud_llm_client::predict_edits_v3::{self, Signature};
   6use cloud_llm_client::{
   7    EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, ZED_VERSION_HEADER_NAME,
   8};
   9use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
  10use edit_prediction::{DataCollectionState, Direction, EditPredictionProvider};
  11use edit_prediction_context::{
  12    DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex,
  13    SyntaxIndexState,
  14};
  15use futures::AsyncReadExt as _;
  16use futures::channel::mpsc;
  17use gpui::http_client::Method;
  18use gpui::{
  19    App, Entity, EntityId, Global, SemanticVersion, SharedString, Subscription, Task, WeakEntity,
  20    http_client, prelude::*,
  21};
  22use language::{
  23    Anchor, Buffer, DiagnosticSet, LanguageServerId, OffsetRangeExt as _, ToOffset as _, ToPoint,
  24};
  25use language::{BufferSnapshot, EditPreview};
  26use language_model::{LlmApiToken, RefreshLlmTokenListener};
  27use project::Project;
  28use release_channel::AppVersion;
  29use std::cmp;
  30use std::collections::{HashMap, VecDeque, hash_map};
  31use std::path::PathBuf;
  32use std::str::FromStr as _;
  33use std::time::{Duration, Instant};
  34use std::{ops::Range, sync::Arc};
  35use thiserror::Error;
  36use util::{ResultExt as _, some_or_debug_panic};
  37use uuid::Uuid;
  38use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
  39
  40const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1);
  41
  42/// Maximum number of events to track.
  43const MAX_EVENT_COUNT: usize = 16;
  44
  45pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions {
  46    max_bytes: 512,
  47    min_bytes: 128,
  48    target_before_cursor_over_total_bytes: 0.5,
  49};
  50
  51pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
  52    excerpt: DEFAULT_EXCERPT_OPTIONS,
  53    max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
  54    max_diagnostic_bytes: 2048,
  55};
  56
  57#[derive(Clone)]
  58struct ZetaGlobal(Entity<Zeta>);
  59
  60impl Global for ZetaGlobal {}
  61
  62pub struct Zeta {
  63    client: Arc<Client>,
  64    user_store: Entity<UserStore>,
  65    llm_token: LlmApiToken,
  66    _llm_token_subscription: Subscription,
  67    projects: HashMap<EntityId, ZetaProject>,
  68    options: ZetaOptions,
  69    update_required: bool,
  70    debug_tx: Option<mpsc::UnboundedSender<Result<PredictionDebugInfo, String>>>,
  71}
  72
  73#[derive(Debug, Clone, PartialEq)]
  74pub struct ZetaOptions {
  75    pub excerpt: EditPredictionExcerptOptions,
  76    pub max_prompt_bytes: usize,
  77    pub max_diagnostic_bytes: usize,
  78}
  79
  80pub struct PredictionDebugInfo {
  81    pub context: EditPredictionContext,
  82    pub retrieval_time: TimeDelta,
  83    pub request: RequestDebugInfo,
  84    pub buffer: WeakEntity<Buffer>,
  85    pub position: language::Anchor,
  86}
  87
  88pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
  89
  90struct ZetaProject {
  91    syntax_index: Entity<SyntaxIndex>,
  92    events: VecDeque<Event>,
  93    registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
  94}
  95
  96struct RegisteredBuffer {
  97    snapshot: BufferSnapshot,
  98    _subscriptions: [gpui::Subscription; 2],
  99}
 100
 101#[derive(Clone)]
 102pub enum Event {
 103    BufferChange {
 104        old_snapshot: BufferSnapshot,
 105        new_snapshot: BufferSnapshot,
 106        timestamp: Instant,
 107    },
 108}
 109
 110impl Zeta {
 111    pub fn global(
 112        client: &Arc<Client>,
 113        user_store: &Entity<UserStore>,
 114        cx: &mut App,
 115    ) -> Entity<Self> {
 116        cx.try_global::<ZetaGlobal>()
 117            .map(|global| global.0.clone())
 118            .unwrap_or_else(|| {
 119                let zeta = cx.new(|cx| Self::new(client.clone(), user_store.clone(), cx));
 120                cx.set_global(ZetaGlobal(zeta.clone()));
 121                zeta
 122            })
 123    }
 124
 125    pub fn new(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut Context<Self>) -> Self {
 126        let refresh_llm_token_listener = RefreshLlmTokenListener::global(cx);
 127
 128        Self {
 129            projects: HashMap::new(),
 130            client,
 131            user_store,
 132            options: DEFAULT_OPTIONS,
 133            llm_token: LlmApiToken::default(),
 134            _llm_token_subscription: cx.subscribe(
 135                &refresh_llm_token_listener,
 136                |this, _listener, _event, cx| {
 137                    let client = this.client.clone();
 138                    let llm_token = this.llm_token.clone();
 139                    cx.spawn(async move |_this, _cx| {
 140                        llm_token.refresh(&client).await?;
 141                        anyhow::Ok(())
 142                    })
 143                    .detach_and_log_err(cx);
 144                },
 145            ),
 146            update_required: false,
 147            debug_tx: None,
 148        }
 149    }
 150
 151    pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<Result<PredictionDebugInfo, String>> {
 152        let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
 153        self.debug_tx = Some(debug_watch_tx);
 154        debug_watch_rx
 155    }
 156
 157    pub fn options(&self) -> &ZetaOptions {
 158        &self.options
 159    }
 160
 161    pub fn set_options(&mut self, options: ZetaOptions) {
 162        self.options = options;
 163    }
 164
 165    pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
 166        self.user_store.read(cx).edit_prediction_usage()
 167    }
 168
 169    pub fn register_project(&mut self, project: &Entity<Project>, cx: &mut App) {
 170        self.get_or_init_zeta_project(project, cx);
 171    }
 172
 173    pub fn register_buffer(
 174        &mut self,
 175        buffer: &Entity<Buffer>,
 176        project: &Entity<Project>,
 177        cx: &mut Context<Self>,
 178    ) {
 179        let zeta_project = self.get_or_init_zeta_project(project, cx);
 180        Self::register_buffer_impl(zeta_project, buffer, project, cx);
 181    }
 182
 183    fn get_or_init_zeta_project(
 184        &mut self,
 185        project: &Entity<Project>,
 186        cx: &mut App,
 187    ) -> &mut ZetaProject {
 188        self.projects
 189            .entry(project.entity_id())
 190            .or_insert_with(|| ZetaProject {
 191                syntax_index: cx.new(|cx| SyntaxIndex::new(project, cx)),
 192                events: VecDeque::new(),
 193                registered_buffers: HashMap::new(),
 194            })
 195    }
 196
 197    fn register_buffer_impl<'a>(
 198        zeta_project: &'a mut ZetaProject,
 199        buffer: &Entity<Buffer>,
 200        project: &Entity<Project>,
 201        cx: &mut Context<Self>,
 202    ) -> &'a mut RegisteredBuffer {
 203        let buffer_id = buffer.entity_id();
 204        match zeta_project.registered_buffers.entry(buffer_id) {
 205            hash_map::Entry::Occupied(entry) => entry.into_mut(),
 206            hash_map::Entry::Vacant(entry) => {
 207                let snapshot = buffer.read(cx).snapshot();
 208                let project_entity_id = project.entity_id();
 209                entry.insert(RegisteredBuffer {
 210                    snapshot,
 211                    _subscriptions: [
 212                        cx.subscribe(buffer, {
 213                            let project = project.downgrade();
 214                            move |this, buffer, event, cx| {
 215                                if let language::BufferEvent::Edited = event
 216                                    && let Some(project) = project.upgrade()
 217                                {
 218                                    this.report_changes_for_buffer(&buffer, &project, cx);
 219                                }
 220                            }
 221                        }),
 222                        cx.observe_release(buffer, move |this, _buffer, _cx| {
 223                            let Some(zeta_project) = this.projects.get_mut(&project_entity_id)
 224                            else {
 225                                return;
 226                            };
 227                            zeta_project.registered_buffers.remove(&buffer_id);
 228                        }),
 229                    ],
 230                })
 231            }
 232        }
 233    }
 234
 235    fn report_changes_for_buffer(
 236        &mut self,
 237        buffer: &Entity<Buffer>,
 238        project: &Entity<Project>,
 239        cx: &mut Context<Self>,
 240    ) -> BufferSnapshot {
 241        let zeta_project = self.get_or_init_zeta_project(project, cx);
 242        let registered_buffer = Self::register_buffer_impl(zeta_project, buffer, project, cx);
 243
 244        let new_snapshot = buffer.read(cx).snapshot();
 245        if new_snapshot.version != registered_buffer.snapshot.version {
 246            let old_snapshot =
 247                std::mem::replace(&mut registered_buffer.snapshot, new_snapshot.clone());
 248            Self::push_event(
 249                zeta_project,
 250                Event::BufferChange {
 251                    old_snapshot,
 252                    new_snapshot: new_snapshot.clone(),
 253                    timestamp: Instant::now(),
 254                },
 255            );
 256        }
 257
 258        new_snapshot
 259    }
 260
 261    fn push_event(zeta_project: &mut ZetaProject, event: Event) {
 262        let events = &mut zeta_project.events;
 263
 264        if let Some(Event::BufferChange {
 265            new_snapshot: last_new_snapshot,
 266            timestamp: last_timestamp,
 267            ..
 268        }) = events.back_mut()
 269        {
 270            // Coalesce edits for the same buffer when they happen one after the other.
 271            let Event::BufferChange {
 272                old_snapshot,
 273                new_snapshot,
 274                timestamp,
 275            } = &event;
 276
 277            if timestamp.duration_since(*last_timestamp) <= BUFFER_CHANGE_GROUPING_INTERVAL
 278                && old_snapshot.remote_id() == last_new_snapshot.remote_id()
 279                && old_snapshot.version == last_new_snapshot.version
 280            {
 281                *last_new_snapshot = new_snapshot.clone();
 282                *last_timestamp = *timestamp;
 283                return;
 284            }
 285        }
 286
 287        if events.len() >= MAX_EVENT_COUNT {
 288            // These are halved instead of popping to improve prompt caching.
 289            events.drain(..MAX_EVENT_COUNT / 2);
 290        }
 291
 292        events.push_back(event);
 293    }
 294
 295    pub fn request_prediction(
 296        &mut self,
 297        project: &Entity<Project>,
 298        buffer: &Entity<Buffer>,
 299        position: language::Anchor,
 300        cx: &mut Context<Self>,
 301    ) -> Task<Result<Option<EditPrediction>>> {
 302        let project_state = self.projects.get(&project.entity_id());
 303
 304        let index_state = project_state.map(|state| {
 305            state
 306                .syntax_index
 307                .read_with(cx, |index, _cx| index.state().clone())
 308        });
 309        let options = self.options.clone();
 310        let snapshot = buffer.read(cx).snapshot();
 311        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 312            return Task::ready(Err(anyhow!("No file path for excerpt")));
 313        };
 314        let client = self.client.clone();
 315        let llm_token = self.llm_token.clone();
 316        let app_version = AppVersion::global(cx);
 317        let worktree_snapshots = project
 318            .read(cx)
 319            .worktrees(cx)
 320            .map(|worktree| worktree.read(cx).snapshot())
 321            .collect::<Vec<_>>();
 322        let debug_tx = self.debug_tx.clone();
 323
 324        let events = project_state
 325            .map(|state| {
 326                state
 327                    .events
 328                    .iter()
 329                    .map(|event| match event {
 330                        Event::BufferChange {
 331                            old_snapshot,
 332                            new_snapshot,
 333                            ..
 334                        } => {
 335                            let path = new_snapshot.file().map(|f| f.path().to_path_buf());
 336
 337                            let old_path = old_snapshot.file().and_then(|f| {
 338                                let old_path = f.path().as_ref();
 339                                if Some(old_path) != path.as_deref() {
 340                                    Some(old_path.to_path_buf())
 341                                } else {
 342                                    None
 343                                }
 344                            });
 345
 346                            predict_edits_v3::Event::BufferChange {
 347                                old_path,
 348                                path,
 349                                diff: language::unified_diff(
 350                                    &old_snapshot.text(),
 351                                    &new_snapshot.text(),
 352                                ),
 353                                //todo: Actually detect if this edit was predicted or not
 354                                predicted: false,
 355                            }
 356                        }
 357                    })
 358                    .collect::<Vec<_>>()
 359            })
 360            .unwrap_or_default();
 361
 362        let diagnostics = snapshot.diagnostic_sets().clone();
 363
 364        let request_task = cx.background_spawn({
 365            let snapshot = snapshot.clone();
 366            let buffer = buffer.clone();
 367            async move {
 368                let index_state = if let Some(index_state) = index_state {
 369                    Some(index_state.lock_owned().await)
 370                } else {
 371                    None
 372                };
 373
 374                let cursor_offset = position.to_offset(&snapshot);
 375                let cursor_point = cursor_offset.to_point(&snapshot);
 376
 377                let before_retrieval = chrono::Utc::now();
 378
 379                let Some(context) = EditPredictionContext::gather_context(
 380                    cursor_point,
 381                    &snapshot,
 382                    &options.excerpt,
 383                    index_state.as_deref(),
 384                ) else {
 385                    return Ok(None);
 386                };
 387
 388                let debug_context = if let Some(debug_tx) = debug_tx {
 389                    Some((debug_tx, context.clone()))
 390                } else {
 391                    None
 392                };
 393
 394                let (diagnostic_groups, diagnostic_groups_truncated) =
 395                    Self::gather_nearby_diagnostics(
 396                        cursor_offset,
 397                        &diagnostics,
 398                        &snapshot,
 399                        options.max_diagnostic_bytes,
 400                    );
 401
 402                let request = make_cloud_request(
 403                    excerpt_path.clone(),
 404                    context,
 405                    events,
 406                    // TODO data collection
 407                    false,
 408                    diagnostic_groups,
 409                    diagnostic_groups_truncated,
 410                    None,
 411                    debug_context.is_some(),
 412                    &worktree_snapshots,
 413                    index_state.as_deref(),
 414                    Some(options.max_prompt_bytes),
 415                );
 416
 417                let retrieval_time = chrono::Utc::now() - before_retrieval;
 418                let response = Self::perform_request(client, llm_token, app_version, request).await;
 419
 420                if let Some((debug_tx, context)) = debug_context {
 421                    debug_tx
 422                        .unbounded_send(response.as_ref().map_err(|err| err.to_string()).and_then(
 423                            |response| {
 424                                let Some(request) =
 425                                    some_or_debug_panic(response.0.debug_info.clone())
 426                                else {
 427                                    return Err("Missing debug info".to_string());
 428                                };
 429                                Ok(PredictionDebugInfo {
 430                                    context,
 431                                    request,
 432                                    retrieval_time,
 433                                    buffer: buffer.downgrade(),
 434                                    position,
 435                                })
 436                            },
 437                        ))
 438                        .ok();
 439                }
 440
 441                anyhow::Ok(Some(response?))
 442            }
 443        });
 444
 445        let buffer = buffer.clone();
 446
 447        cx.spawn(async move |this, cx| {
 448            match request_task.await {
 449                Ok(Some((response, usage))) => {
 450                    log::debug!("predicted edits: {:?}", &response.edits);
 451
 452                    if let Some(usage) = usage {
 453                        this.update(cx, |this, cx| {
 454                            this.user_store.update(cx, |user_store, cx| {
 455                                user_store.update_edit_prediction_usage(usage, cx);
 456                            });
 457                        })
 458                        .ok();
 459                    }
 460
 461                    // TODO telemetry: duration, etc
 462
 463                    // TODO produce smaller edits by diffing against snapshot first
 464                    //
 465                    // Cloud returns entire snippets/excerpts ranges as they were included
 466                    // in the request, but we should display smaller edits to the user.
 467                    //
 468                    // We can do this by computing a diff of each one against the snapshot.
 469                    // Similar to zeta::Zeta::compute_edits, but per edit.
 470                    let edits = response
 471                        .edits
 472                        .into_iter()
 473                        .map(|edit| {
 474                            // TODO edits to different files
 475                            (
 476                                snapshot.anchor_before(edit.range.start)
 477                                    ..snapshot.anchor_before(edit.range.end),
 478                                edit.content,
 479                            )
 480                        })
 481                        .collect::<Vec<_>>()
 482                        .into();
 483
 484                    let Some((edits, snapshot, edit_preview_task)) =
 485                        buffer.read_with(cx, |buffer, cx| {
 486                            let new_snapshot = buffer.snapshot();
 487                            let edits: Arc<[_]> =
 488                                interpolate(&snapshot, &new_snapshot, edits)?.into();
 489                            Some((edits.clone(), new_snapshot, buffer.preview_edits(edits, cx)))
 490                        })?
 491                    else {
 492                        return Ok(None);
 493                    };
 494
 495                    Ok(Some(EditPrediction {
 496                        id: EditPredictionId(response.request_id),
 497                        edits,
 498                        snapshot,
 499                        edit_preview: edit_preview_task.await,
 500                    }))
 501                }
 502                Ok(None) => Ok(None),
 503                Err(err) => {
 504                    if err.is::<ZedUpdateRequiredError>() {
 505                        cx.update(|cx| {
 506                            this.update(cx, |this, _cx| {
 507                                this.update_required = true;
 508                            })
 509                            .ok();
 510
 511                            let error_message: SharedString = err.to_string().into();
 512                            show_app_notification(
 513                                NotificationId::unique::<ZedUpdateRequiredError>(),
 514                                cx,
 515                                move |cx| {
 516                                    cx.new(|cx| {
 517                                        ErrorMessagePrompt::new(error_message.clone(), cx)
 518                                            .with_link_button(
 519                                                "Update Zed",
 520                                                "https://zed.dev/releases",
 521                                            )
 522                                    })
 523                                },
 524                            );
 525                        })
 526                        .ok();
 527                    }
 528
 529                    Err(err)
 530                }
 531            }
 532        })
 533    }
 534
 535    async fn perform_request(
 536        client: Arc<Client>,
 537        llm_token: LlmApiToken,
 538        app_version: SemanticVersion,
 539        request: predict_edits_v3::PredictEditsRequest,
 540    ) -> Result<(
 541        predict_edits_v3::PredictEditsResponse,
 542        Option<EditPredictionUsage>,
 543    )> {
 544        let http_client = client.http_client();
 545        let mut token = llm_token.acquire(&client).await?;
 546        let mut did_retry = false;
 547
 548        loop {
 549            let request_builder = http_client::Request::builder().method(Method::POST);
 550            let request_builder =
 551                if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
 552                    request_builder.uri(predict_edits_url)
 553                } else {
 554                    request_builder.uri(
 555                        http_client
 556                            .build_zed_llm_url("/predict_edits/v3", &[])?
 557                            .as_ref(),
 558                    )
 559                };
 560            let request = request_builder
 561                .header("Content-Type", "application/json")
 562                .header("Authorization", format!("Bearer {}", token))
 563                .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
 564                .body(serde_json::to_string(&request)?.into())?;
 565
 566            let mut response = http_client.send(request).await?;
 567
 568            if let Some(minimum_required_version) = response
 569                .headers()
 570                .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
 571                .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
 572            {
 573                anyhow::ensure!(
 574                    app_version >= minimum_required_version,
 575                    ZedUpdateRequiredError {
 576                        minimum_version: minimum_required_version
 577                    }
 578                );
 579            }
 580
 581            if response.status().is_success() {
 582                let usage = EditPredictionUsage::from_headers(response.headers()).ok();
 583
 584                let mut body = Vec::new();
 585                response.body_mut().read_to_end(&mut body).await?;
 586                return Ok((serde_json::from_slice(&body)?, usage));
 587            } else if !did_retry
 588                && response
 589                    .headers()
 590                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
 591                    .is_some()
 592            {
 593                did_retry = true;
 594                token = llm_token.refresh(&client).await?;
 595            } else {
 596                let mut body = String::new();
 597                response.body_mut().read_to_string(&mut body).await?;
 598                anyhow::bail!(
 599                    "error predicting edits.\nStatus: {:?}\nBody: {}",
 600                    response.status(),
 601                    body
 602                );
 603            }
 604        }
 605    }
 606
 607    fn gather_nearby_diagnostics(
 608        cursor_offset: usize,
 609        diagnostic_sets: &[(LanguageServerId, DiagnosticSet)],
 610        snapshot: &BufferSnapshot,
 611        max_diagnostics_bytes: usize,
 612    ) -> (Vec<predict_edits_v3::DiagnosticGroup>, bool) {
 613        // TODO: Could make this more efficient
 614        let mut diagnostic_groups = Vec::new();
 615        for (language_server_id, diagnostics) in diagnostic_sets {
 616            let mut groups = Vec::new();
 617            diagnostics.groups(*language_server_id, &mut groups, &snapshot);
 618            diagnostic_groups.extend(
 619                groups
 620                    .into_iter()
 621                    .map(|(_, group)| group.resolve::<usize>(&snapshot)),
 622            );
 623        }
 624
 625        // sort by proximity to cursor
 626        diagnostic_groups.sort_by_key(|group| {
 627            let range = &group.entries[group.primary_ix].range;
 628            if range.start >= cursor_offset {
 629                range.start - cursor_offset
 630            } else if cursor_offset >= range.end {
 631                cursor_offset - range.end
 632            } else {
 633                (cursor_offset - range.start).min(range.end - cursor_offset)
 634            }
 635        });
 636
 637        let mut results = Vec::new();
 638        let mut diagnostic_groups_truncated = false;
 639        let mut diagnostics_byte_count = 0;
 640        for group in diagnostic_groups {
 641            let raw_value = serde_json::value::to_raw_value(&group).unwrap();
 642            diagnostics_byte_count += raw_value.get().len();
 643            if diagnostics_byte_count > max_diagnostics_bytes {
 644                diagnostic_groups_truncated = true;
 645                break;
 646            }
 647            results.push(predict_edits_v3::DiagnosticGroup(raw_value));
 648        }
 649
 650        (results, diagnostic_groups_truncated)
 651    }
 652
 653    // TODO: Dedupe with similar code in request_prediction?
 654    pub fn cloud_request_for_zeta_cli(
 655        &mut self,
 656        project: &Entity<Project>,
 657        buffer: &Entity<Buffer>,
 658        position: language::Anchor,
 659        cx: &mut Context<Self>,
 660    ) -> Task<Result<predict_edits_v3::PredictEditsRequest>> {
 661        let project_state = self.projects.get(&project.entity_id());
 662
 663        let index_state = project_state.map(|state| {
 664            state
 665                .syntax_index
 666                .read_with(cx, |index, _cx| index.state().clone())
 667        });
 668        let options = self.options.clone();
 669        let snapshot = buffer.read(cx).snapshot();
 670        let Some(excerpt_path) = snapshot.file().map(|path| path.full_path(cx)) else {
 671            return Task::ready(Err(anyhow!("No file path for excerpt")));
 672        };
 673        let worktree_snapshots = project
 674            .read(cx)
 675            .worktrees(cx)
 676            .map(|worktree| worktree.read(cx).snapshot())
 677            .collect::<Vec<_>>();
 678
 679        cx.background_spawn(async move {
 680            let index_state = if let Some(index_state) = index_state {
 681                Some(index_state.lock_owned().await)
 682            } else {
 683                None
 684            };
 685
 686            let cursor_point = position.to_point(&snapshot);
 687
 688            let debug_info = true;
 689            EditPredictionContext::gather_context(
 690                cursor_point,
 691                &snapshot,
 692                &options.excerpt,
 693                index_state.as_deref(),
 694            )
 695            .context("Failed to select excerpt")
 696            .map(|context| {
 697                make_cloud_request(
 698                    excerpt_path.clone(),
 699                    context,
 700                    // TODO pass everything
 701                    Vec::new(),
 702                    false,
 703                    Vec::new(),
 704                    false,
 705                    None,
 706                    debug_info,
 707                    &worktree_snapshots,
 708                    index_state.as_deref(),
 709                    Some(options.max_prompt_bytes),
 710                )
 711            })
 712        })
 713    }
 714}
 715
 716#[derive(Error, Debug)]
 717#[error(
 718    "You must update to Zed version {minimum_version} or higher to continue using edit predictions."
 719)]
 720pub struct ZedUpdateRequiredError {
 721    minimum_version: SemanticVersion,
 722}
 723
 724pub struct ZetaEditPredictionProvider {
 725    zeta: Entity<Zeta>,
 726    current_prediction: Option<CurrentEditPrediction>,
 727    next_pending_prediction_id: usize,
 728    pending_predictions: ArrayVec<PendingPrediction, 2>,
 729    last_request_timestamp: Instant,
 730}
 731
 732impl ZetaEditPredictionProvider {
 733    pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
 734
 735    pub fn new(
 736        project: Option<&Entity<Project>>,
 737        client: &Arc<Client>,
 738        user_store: &Entity<UserStore>,
 739        cx: &mut App,
 740    ) -> Self {
 741        let zeta = Zeta::global(client, user_store, cx);
 742        if let Some(project) = project {
 743            zeta.update(cx, |zeta, cx| {
 744                zeta.register_project(project, cx);
 745            });
 746        }
 747
 748        Self {
 749            zeta,
 750            current_prediction: None,
 751            next_pending_prediction_id: 0,
 752            pending_predictions: ArrayVec::new(),
 753            last_request_timestamp: Instant::now(),
 754        }
 755    }
 756}
 757
 758#[derive(Clone)]
 759struct CurrentEditPrediction {
 760    buffer_id: EntityId,
 761    prediction: EditPrediction,
 762}
 763
 764impl CurrentEditPrediction {
 765    fn should_replace_prediction(&self, old_prediction: &Self, snapshot: &BufferSnapshot) -> bool {
 766        if self.buffer_id != old_prediction.buffer_id {
 767            return true;
 768        }
 769
 770        let Some(old_edits) = old_prediction.prediction.interpolate(snapshot) else {
 771            return true;
 772        };
 773        let Some(new_edits) = self.prediction.interpolate(snapshot) else {
 774            return false;
 775        };
 776
 777        if old_edits.len() == 1 && new_edits.len() == 1 {
 778            let (old_range, old_text) = &old_edits[0];
 779            let (new_range, new_text) = &new_edits[0];
 780            new_range == old_range && new_text.starts_with(old_text)
 781        } else {
 782            true
 783        }
 784    }
 785}
 786
 787#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 788pub struct EditPredictionId(Uuid);
 789
 790impl From<EditPredictionId> for gpui::ElementId {
 791    fn from(value: EditPredictionId) -> Self {
 792        gpui::ElementId::Uuid(value.0)
 793    }
 794}
 795
 796impl std::fmt::Display for EditPredictionId {
 797    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 798        write!(f, "{}", self.0)
 799    }
 800}
 801
 802#[derive(Clone)]
 803pub struct EditPrediction {
 804    id: EditPredictionId,
 805    edits: Arc<[(Range<Anchor>, String)]>,
 806    snapshot: BufferSnapshot,
 807    edit_preview: EditPreview,
 808}
 809
 810impl EditPrediction {
 811    fn interpolate(&self, new_snapshot: &BufferSnapshot) -> Option<Vec<(Range<Anchor>, String)>> {
 812        interpolate(&self.snapshot, new_snapshot, self.edits.clone())
 813    }
 814}
 815
 816struct PendingPrediction {
 817    id: usize,
 818    _task: Task<()>,
 819}
 820
 821impl EditPredictionProvider for ZetaEditPredictionProvider {
 822    fn name() -> &'static str {
 823        "zed-predict2"
 824    }
 825
 826    fn display_name() -> &'static str {
 827        "Zed's Edit Predictions 2"
 828    }
 829
 830    fn show_completions_in_menu() -> bool {
 831        true
 832    }
 833
 834    fn show_tab_accept_marker() -> bool {
 835        true
 836    }
 837
 838    fn data_collection_state(&self, _cx: &App) -> DataCollectionState {
 839        // TODO [zeta2]
 840        DataCollectionState::Unsupported
 841    }
 842
 843    fn toggle_data_collection(&mut self, _cx: &mut App) {
 844        // TODO [zeta2]
 845    }
 846
 847    fn usage(&self, cx: &App) -> Option<client::EditPredictionUsage> {
 848        self.zeta.read(cx).usage(cx)
 849    }
 850
 851    fn is_enabled(
 852        &self,
 853        _buffer: &Entity<language::Buffer>,
 854        _cursor_position: language::Anchor,
 855        _cx: &App,
 856    ) -> bool {
 857        true
 858    }
 859
 860    fn is_refreshing(&self) -> bool {
 861        !self.pending_predictions.is_empty()
 862    }
 863
 864    fn refresh(
 865        &mut self,
 866        project: Option<Entity<project::Project>>,
 867        buffer: Entity<language::Buffer>,
 868        cursor_position: language::Anchor,
 869        _debounce: bool,
 870        cx: &mut Context<Self>,
 871    ) {
 872        let Some(project) = project else {
 873            return;
 874        };
 875
 876        if self
 877            .zeta
 878            .read(cx)
 879            .user_store
 880            .read_with(cx, |user_store, _cx| {
 881                user_store.account_too_young() || user_store.has_overdue_invoices()
 882            })
 883        {
 884            return;
 885        }
 886
 887        if let Some(current_prediction) = self.current_prediction.as_ref() {
 888            let snapshot = buffer.read(cx).snapshot();
 889            if current_prediction
 890                .prediction
 891                .interpolate(&snapshot)
 892                .is_some()
 893            {
 894                return;
 895            }
 896        }
 897
 898        let pending_prediction_id = self.next_pending_prediction_id;
 899        self.next_pending_prediction_id += 1;
 900        let last_request_timestamp = self.last_request_timestamp;
 901
 902        let task = cx.spawn(async move |this, cx| {
 903            if let Some(timeout) = (last_request_timestamp + Self::THROTTLE_TIMEOUT)
 904                .checked_duration_since(Instant::now())
 905            {
 906                cx.background_executor().timer(timeout).await;
 907            }
 908
 909            let prediction_request = this.update(cx, |this, cx| {
 910                this.last_request_timestamp = Instant::now();
 911                this.zeta.update(cx, |zeta, cx| {
 912                    zeta.request_prediction(&project, &buffer, cursor_position, cx)
 913                })
 914            });
 915
 916            let prediction = match prediction_request {
 917                Ok(prediction_request) => {
 918                    let prediction_request = prediction_request.await;
 919                    prediction_request.map(|c| {
 920                        c.map(|prediction| CurrentEditPrediction {
 921                            buffer_id: buffer.entity_id(),
 922                            prediction,
 923                        })
 924                    })
 925                }
 926                Err(error) => Err(error),
 927            };
 928
 929            this.update(cx, |this, cx| {
 930                if this.pending_predictions[0].id == pending_prediction_id {
 931                    this.pending_predictions.remove(0);
 932                } else {
 933                    this.pending_predictions.clear();
 934                }
 935
 936                let Some(new_prediction) = prediction
 937                    .context("edit prediction failed")
 938                    .log_err()
 939                    .flatten()
 940                else {
 941                    cx.notify();
 942                    return;
 943                };
 944
 945                if let Some(old_prediction) = this.current_prediction.as_ref() {
 946                    let snapshot = buffer.read(cx).snapshot();
 947                    if new_prediction.should_replace_prediction(old_prediction, &snapshot) {
 948                        this.current_prediction = Some(new_prediction);
 949                    }
 950                } else {
 951                    this.current_prediction = Some(new_prediction);
 952                }
 953
 954                cx.notify();
 955            })
 956            .ok();
 957        });
 958
 959        // We always maintain at most two pending predictions. When we already
 960        // have two, we replace the newest one.
 961        if self.pending_predictions.len() <= 1 {
 962            self.pending_predictions.push(PendingPrediction {
 963                id: pending_prediction_id,
 964                _task: task,
 965            });
 966        } else if self.pending_predictions.len() == 2 {
 967            self.pending_predictions.pop();
 968            self.pending_predictions.push(PendingPrediction {
 969                id: pending_prediction_id,
 970                _task: task,
 971            });
 972        }
 973
 974        cx.notify();
 975    }
 976
 977    fn cycle(
 978        &mut self,
 979        _buffer: Entity<language::Buffer>,
 980        _cursor_position: language::Anchor,
 981        _direction: Direction,
 982        _cx: &mut Context<Self>,
 983    ) {
 984    }
 985
 986    fn accept(&mut self, _cx: &mut Context<Self>) {
 987        // TODO [zeta2] report accept
 988        self.current_prediction.take();
 989        self.pending_predictions.clear();
 990    }
 991
 992    fn discard(&mut self, _cx: &mut Context<Self>) {
 993        self.pending_predictions.clear();
 994        self.current_prediction.take();
 995    }
 996
 997    fn suggest(
 998        &mut self,
 999        buffer: &Entity<language::Buffer>,
1000        cursor_position: language::Anchor,
1001        cx: &mut Context<Self>,
1002    ) -> Option<edit_prediction::EditPrediction> {
1003        let CurrentEditPrediction {
1004            buffer_id,
1005            prediction,
1006            ..
1007        } = self.current_prediction.as_mut()?;
1008
1009        // Invalidate previous prediction if it was generated for a different buffer.
1010        if *buffer_id != buffer.entity_id() {
1011            self.current_prediction.take();
1012            return None;
1013        }
1014
1015        let buffer = buffer.read(cx);
1016        let Some(edits) = prediction.interpolate(&buffer.snapshot()) else {
1017            self.current_prediction.take();
1018            return None;
1019        };
1020
1021        let cursor_row = cursor_position.to_point(buffer).row;
1022        let (closest_edit_ix, (closest_edit_range, _)) =
1023            edits.iter().enumerate().min_by_key(|(_, (range, _))| {
1024                let distance_from_start = cursor_row.abs_diff(range.start.to_point(buffer).row);
1025                let distance_from_end = cursor_row.abs_diff(range.end.to_point(buffer).row);
1026                cmp::min(distance_from_start, distance_from_end)
1027            })?;
1028
1029        let mut edit_start_ix = closest_edit_ix;
1030        for (range, _) in edits[..edit_start_ix].iter().rev() {
1031            let distance_from_closest_edit =
1032                closest_edit_range.start.to_point(buffer).row - range.end.to_point(buffer).row;
1033            if distance_from_closest_edit <= 1 {
1034                edit_start_ix -= 1;
1035            } else {
1036                break;
1037            }
1038        }
1039
1040        let mut edit_end_ix = closest_edit_ix + 1;
1041        for (range, _) in &edits[edit_end_ix..] {
1042            let distance_from_closest_edit =
1043                range.start.to_point(buffer).row - closest_edit_range.end.to_point(buffer).row;
1044            if distance_from_closest_edit <= 1 {
1045                edit_end_ix += 1;
1046            } else {
1047                break;
1048            }
1049        }
1050
1051        Some(edit_prediction::EditPrediction {
1052            id: Some(prediction.id.to_string().into()),
1053            edits: edits[edit_start_ix..edit_end_ix].to_vec(),
1054            edit_preview: Some(prediction.edit_preview.clone()),
1055        })
1056    }
1057}
1058
1059fn make_cloud_request(
1060    excerpt_path: PathBuf,
1061    context: EditPredictionContext,
1062    events: Vec<predict_edits_v3::Event>,
1063    can_collect_data: bool,
1064    diagnostic_groups: Vec<predict_edits_v3::DiagnosticGroup>,
1065    diagnostic_groups_truncated: bool,
1066    git_info: Option<cloud_llm_client::PredictEditsGitInfo>,
1067    debug_info: bool,
1068    worktrees: &Vec<worktree::Snapshot>,
1069    index_state: Option<&SyntaxIndexState>,
1070    prompt_max_bytes: Option<usize>,
1071) -> predict_edits_v3::PredictEditsRequest {
1072    let mut signatures = Vec::new();
1073    let mut declaration_to_signature_index = HashMap::default();
1074    let mut referenced_declarations = Vec::new();
1075
1076    for snippet in context.snippets {
1077        let project_entry_id = snippet.declaration.project_entry_id();
1078        let Some(path) = worktrees.iter().find_map(|worktree| {
1079            worktree.entry_for_id(project_entry_id).map(|entry| {
1080                let mut full_path = PathBuf::new();
1081                full_path.push(worktree.root_name());
1082                full_path.push(&entry.path);
1083                full_path
1084            })
1085        }) else {
1086            continue;
1087        };
1088
1089        let parent_index = index_state.and_then(|index_state| {
1090            snippet.declaration.parent().and_then(|parent| {
1091                add_signature(
1092                    parent,
1093                    &mut declaration_to_signature_index,
1094                    &mut signatures,
1095                    index_state,
1096                )
1097            })
1098        });
1099
1100        let (text, text_is_truncated) = snippet.declaration.item_text();
1101        referenced_declarations.push(predict_edits_v3::ReferencedDeclaration {
1102            path,
1103            text: text.into(),
1104            range: snippet.declaration.item_range(),
1105            text_is_truncated,
1106            signature_range: snippet.declaration.signature_range_in_item_text(),
1107            parent_index,
1108            score_components: snippet.score_components,
1109            signature_score: snippet.scores.signature,
1110            declaration_score: snippet.scores.declaration,
1111        });
1112    }
1113
1114    let excerpt_parent = index_state.and_then(|index_state| {
1115        context
1116            .excerpt
1117            .parent_declarations
1118            .last()
1119            .and_then(|(parent, _)| {
1120                add_signature(
1121                    *parent,
1122                    &mut declaration_to_signature_index,
1123                    &mut signatures,
1124                    index_state,
1125                )
1126            })
1127    });
1128
1129    predict_edits_v3::PredictEditsRequest {
1130        excerpt_path,
1131        excerpt: context.excerpt_text.body,
1132        excerpt_range: context.excerpt.range,
1133        cursor_offset: context.cursor_offset_in_excerpt,
1134        referenced_declarations,
1135        signatures,
1136        excerpt_parent,
1137        events,
1138        can_collect_data,
1139        diagnostic_groups,
1140        diagnostic_groups_truncated,
1141        git_info,
1142        debug_info,
1143        prompt_max_bytes,
1144    }
1145}
1146
1147fn add_signature(
1148    declaration_id: DeclarationId,
1149    declaration_to_signature_index: &mut HashMap<DeclarationId, usize>,
1150    signatures: &mut Vec<Signature>,
1151    index: &SyntaxIndexState,
1152) -> Option<usize> {
1153    if let Some(signature_index) = declaration_to_signature_index.get(&declaration_id) {
1154        return Some(*signature_index);
1155    }
1156    let Some(parent_declaration) = index.declaration(declaration_id) else {
1157        log::error!("bug: missing parent declaration");
1158        return None;
1159    };
1160    let parent_index = parent_declaration.parent().and_then(|parent| {
1161        add_signature(parent, declaration_to_signature_index, signatures, index)
1162    });
1163    let (text, text_is_truncated) = parent_declaration.signature_text();
1164    let signature_index = signatures.len();
1165    signatures.push(Signature {
1166        text: text.into(),
1167        text_is_truncated,
1168        parent_index,
1169        range: parent_declaration.signature_range(),
1170    });
1171    declaration_to_signature_index.insert(declaration_id, signature_index);
1172    Some(signature_index)
1173}
1174
1175fn interpolate(
1176    old_snapshot: &BufferSnapshot,
1177    new_snapshot: &BufferSnapshot,
1178    current_edits: Arc<[(Range<Anchor>, String)]>,
1179) -> Option<Vec<(Range<Anchor>, String)>> {
1180    let mut edits = Vec::new();
1181
1182    let mut model_edits = current_edits.iter().peekable();
1183    for user_edit in new_snapshot.edits_since::<usize>(&old_snapshot.version) {
1184        while let Some((model_old_range, _)) = model_edits.peek() {
1185            let model_old_range = model_old_range.to_offset(old_snapshot);
1186            if model_old_range.end < user_edit.old.start {
1187                let (model_old_range, model_new_text) = model_edits.next().unwrap();
1188                edits.push((model_old_range.clone(), model_new_text.clone()));
1189            } else {
1190                break;
1191            }
1192        }
1193
1194        if let Some((model_old_range, model_new_text)) = model_edits.peek() {
1195            let model_old_offset_range = model_old_range.to_offset(old_snapshot);
1196            if user_edit.old == model_old_offset_range {
1197                let user_new_text = new_snapshot
1198                    .text_for_range(user_edit.new.clone())
1199                    .collect::<String>();
1200
1201                if let Some(model_suffix) = model_new_text.strip_prefix(&user_new_text) {
1202                    if !model_suffix.is_empty() {
1203                        let anchor = old_snapshot.anchor_after(user_edit.old.end);
1204                        edits.push((anchor..anchor, model_suffix.to_string()));
1205                    }
1206
1207                    model_edits.next();
1208                    continue;
1209                }
1210            }
1211        }
1212
1213        return None;
1214    }
1215
1216    edits.extend(model_edits.cloned());
1217
1218    if edits.is_empty() { None } else { Some(edits) }
1219}
1220
1221#[cfg(test)]
1222mod tests {
1223    use super::*;
1224    use gpui::TestAppContext;
1225
1226    #[gpui::test]
1227    async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
1228        let buffer = cx.new(|cx| Buffer::local("Lorem ipsum dolor", cx));
1229        let edits: Arc<[(Range<Anchor>, String)]> = cx.update(|cx| {
1230            to_prediction_edits(
1231                [(2..5, "REM".to_string()), (9..11, "".to_string())],
1232                &buffer,
1233                cx,
1234            )
1235            .into()
1236        });
1237
1238        let edit_preview = cx
1239            .read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
1240            .await;
1241
1242        let prediction = EditPrediction {
1243            id: EditPredictionId(Uuid::new_v4()),
1244            edits,
1245            snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
1246            edit_preview,
1247        };
1248
1249        cx.update(|cx| {
1250            assert_eq!(
1251                from_prediction_edits(
1252                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1253                    &buffer,
1254                    cx
1255                ),
1256                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1257            );
1258
1259            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
1260            assert_eq!(
1261                from_prediction_edits(
1262                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1263                    &buffer,
1264                    cx
1265                ),
1266                vec![(2..2, "REM".to_string()), (6..8, "".to_string())]
1267            );
1268
1269            buffer.update(cx, |buffer, cx| buffer.undo(cx));
1270            assert_eq!(
1271                from_prediction_edits(
1272                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1273                    &buffer,
1274                    cx
1275                ),
1276                vec![(2..5, "REM".to_string()), (9..11, "".to_string())]
1277            );
1278
1279            buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
1280            assert_eq!(
1281                from_prediction_edits(
1282                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1283                    &buffer,
1284                    cx
1285                ),
1286                vec![(3..3, "EM".to_string()), (7..9, "".to_string())]
1287            );
1288
1289            buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
1290            assert_eq!(
1291                from_prediction_edits(
1292                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1293                    &buffer,
1294                    cx
1295                ),
1296                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1297            );
1298
1299            buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
1300            assert_eq!(
1301                from_prediction_edits(
1302                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1303                    &buffer,
1304                    cx
1305                ),
1306                vec![(9..11, "".to_string())]
1307            );
1308
1309            buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
1310            assert_eq!(
1311                from_prediction_edits(
1312                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1313                    &buffer,
1314                    cx
1315                ),
1316                vec![(4..4, "M".to_string()), (8..10, "".to_string())]
1317            );
1318
1319            buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
1320            assert_eq!(
1321                from_prediction_edits(
1322                    &prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
1323                    &buffer,
1324                    cx
1325                ),
1326                vec![(4..4, "M".to_string())]
1327            );
1328
1329            buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
1330            assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
1331        })
1332    }
1333
1334    fn to_prediction_edits(
1335        iterator: impl IntoIterator<Item = (Range<usize>, String)>,
1336        buffer: &Entity<Buffer>,
1337        cx: &App,
1338    ) -> Vec<(Range<Anchor>, String)> {
1339        let buffer = buffer.read(cx);
1340        iterator
1341            .into_iter()
1342            .map(|(range, text)| {
1343                (
1344                    buffer.anchor_after(range.start)..buffer.anchor_before(range.end),
1345                    text,
1346                )
1347            })
1348            .collect()
1349    }
1350
1351    fn from_prediction_edits(
1352        editor_edits: &[(Range<Anchor>, String)],
1353        buffer: &Entity<Buffer>,
1354        cx: &App,
1355    ) -> Vec<(Range<usize>, String)> {
1356        let buffer = buffer.read(cx);
1357        editor_edits
1358            .iter()
1359            .map(|(range, text)| {
1360                (
1361                    range.start.to_offset(buffer)..range.end.to_offset(buffer),
1362                    text.clone(),
1363                )
1364            })
1365            .collect()
1366    }
1367}